Commits (35 total; changes shown from 25)
d3b84cb  varlen maba (oleksost, Aug 13, 2025)
79a4565  requirement (oleksost, Aug 13, 2025)
1657a1b  docker (oleksost, Aug 13, 2025)
2b171eb  test varlen mamba (oleksost, Aug 15, 2025)
115c1ec  wip (oleksost, Aug 19, 2025)
37d3be8  cleanup (oleksost, Aug 19, 2025)
1b20268  Merge branch 'mamba_varlen' into tp_mamba2 (oleksost, Aug 19, 2025)
35c6f20  wip (oleksost, Aug 20, 2025)
17f86fd  wip (oleksost, Aug 20, 2025)
adb0666  wip (oleksost, Aug 20, 2025)
bc25e74  mamba2 nemotron h tp (oleksost, Aug 21, 2025)
7c5fb0a  modeling (oleksost, Aug 22, 2025)
9cef978  convertion + MIL init (oleksost, Aug 25, 2025)
662e9ef  convertion (oleksost, Aug 25, 2025)
f78055c  undo requirement varlen for m2 testing (oleksost, Aug 25, 2025)
eb8a54e  varlen (oleksost, Aug 25, 2025)
33281d5  wip (oleksost, Aug 25, 2025)
2a5d0f9  rms norm (oleksost, Aug 26, 2025)
7a047b4  clean up (oleksost, Aug 26, 2025)
7a09387  TP RMS norm (oleksost, Sep 16, 2025)
03a7ac2  TP RMS norm (oleksost, Sep 16, 2025)
bd85e85  Merge branch 'hybrid_dev' into tp_mamba2 (oleksost, Sep 16, 2025)
a3cb3e0  nvm (oleksost, Sep 16, 2025)
826f2f0  nvm (oleksost, Sep 17, 2025)
c9c412e  nvm (oleksost, Sep 17, 2025)
33e9597  wip (oleksost, Sep 18, 2025)
7f3bfe9  modelling mamba2 (oleksost, Sep 22, 2025)
bad4c3b  wip (oleksost, Sep 22, 2025)
fd617c8  mamba2 with rms norm not per head (oleksost, Sep 22, 2025)
799ec67  per head norm (oleksost, Sep 22, 2025)
85afd22  per head norm (oleksost, Sep 22, 2025)
157ce73  multihead norm (oleksost, Sep 22, 2025)
dfb75ae  norm per layer (oleksost, Sep 23, 2025)
4003f37  nvm (oleksost, Sep 23, 2025)
70a04e3  clean (oleksost, Sep 23, 2025)
10 changes: 10 additions & 0 deletions fast_llm/layers/common/config.py
@@ -38,6 +38,16 @@ class NormalizationImplementation(str, enum.Enum):
triton = "triton"


class TPRMSNormImplementation(str, enum.Enum):
"""
An enum for the available implementations of tensor-parallel RMS norm.
"""

fused_redtensor = "fused_redtensor"
autograd_redstats = "autograd_redstats"
torch_comp_redstats = "torch_comp_redstats"


@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):
pass
121 changes: 120 additions & 1 deletion fast_llm/layers/common/normalization.py
@@ -1,10 +1,12 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp, all_reduce # noqa

from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.triton.normalization import triton_normalization_autograd
from fast_llm.layers.common.config import NormalizationImplementation
from fast_llm.layers.common.config import NormalizationImplementation, TPRMSNormImplementation
from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_
from fast_llm.utils import Assert

@@ -243,6 +245,7 @@ def __init__(
):
super().__init__()
assert not hidden_dim.is_parallel

self._eps = eps
self._zero_centered = zero_centered
if implementation == NormalizationImplementation.auto:
@@ -288,3 +291,119 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:

def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps)


class InputParallelGatedRMSNorm(torch.nn.Module):
def __init__(
self,
hidden_dim: TensorDim,
*,
eps=1e-5,
implementation: TPRMSNormImplementation = TPRMSNormImplementation.autograd_redstats,
weight_init_method=None,
zero_centered: bool = False,
lr_scale: float | None = None,
norm_before_gate: bool = True,
):
super().__init__()
self.group = hidden_dim.parallel_group
self.n_groups = hidden_dim.parallel_dim.size
self._norm_before_gate = norm_before_gate
self.hidden_dim_global = hidden_dim.parallel_dim.size * hidden_dim.size

self._eps = eps
self._zero_centered = zero_centered

if weight_init_method is None:
weight_init_method = init_zeros_ if self._zero_centered else init_ones_
if implementation == TPRMSNormImplementation.fused_redtensor:
raise NotImplementedError("Fused red tensor implementation is not implemented yet.")
self._forward = self._forward_fused_red_tensor
elif implementation == TPRMSNormImplementation.autograd_redstats:
self._forward = self._forward_distributed
elif implementation == TPRMSNormImplementation.torch_comp_redstats:
self._forward = self._forward_tc_distributed
else:
raise NotImplementedError(implementation)

self.weight = ParameterMeta.from_dims( # local weights
(hidden_dim,),
init_method=weight_init_method,
weight_decay=False,
auto_grad_accumulation=True,
lr_scale=lr_scale,
)
self.normalized_shape = self.weight.shape

def forward(self, input_: torch.Tensor) -> torch.Tensor:
return self._forward(input_)

def _forward_fused_red_tensor(self, input_: torch.Tensor) -> torch.Tensor:
return _TPRMSNormFnRedTensor.apply(input_, self.normalized_shape, self.weight, self._eps)

def _forward_distributed(self, input_: torch.Tensor) -> torch.Tensor:
return _TPRMSNormFn.apply(input_, self.weight, self._eps, self.group, self.hidden_dim_global)

def _forward_tc_distributed(self, input_: torch.Tensor) -> torch.Tensor:
return rms_norm_distributed(input_, self.weight, self._eps, self.group, self.hidden_dim_global)


class _TPRMSNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w, eps: float, group, H_global: int):
x_dtype = x.dtype
x = x.float().contiguous()
w = w.float()

# global stats
ss_local = x.square().sum(dim=-1, keepdim=True) # [S,B,1]
# ss_global = ss_local.sum(dim=0)
ss_global = torch.distributed.nn.functional.all_reduce(ss_local, op=dist.ReduceOp.SUM, group=group)
inv_rms = torch.rsqrt(ss_global / float(H_global) + eps) # [S,B,1]

y = (x * inv_rms) * w  # [S,B,H_local]

# Save minimal stuff for backward
ctx.save_for_backward(x, w, inv_rms)
ctx.group = group
ctx.H = H_global
return y.to(dtype=x_dtype)

@staticmethod
def backward(ctx, gy):
x, w, inv_rms = ctx.saved_tensors
group, H = ctx.group, ctx.H

gy = gy.float()
x = x.float()
w = w.float()
inv_rms = inv_rms.float()

gy_pre = gy

# RMSNorm backward for y_pre = (x_g * inv_rms) * w
# Note: when gate-before, x_g = x * gate
x_eff = x
gw = gy_pre * w # [B,S,H_local]
local_dot = (gw * x_eff).sum(dim=-1, keepdim=True) # [B,S,1]
global_dot = torch.distributed.nn.functional.all_reduce(local_dot, op=dist.ReduceOp.SUM, group=group)
inv_rms3 = inv_rms * inv_rms * inv_rms
gx = inv_rms * gw - (inv_rms3 / float(H)) * x_eff * global_dot

gw = (gy_pre * (x_eff * inv_rms)).sum(dim=(0, 1)) # [H_local]

return gx, gw, None, None, None, None
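
For reference, a sketch of the math _TPRMSNormFn implements, with H the global hidden size, g_i = dL/dy_i, and sums over j running over the full hidden dimension (the cross-shard part of each sum is what the all_reduce supplies):

    \sigma^2 = \frac{1}{H} \sum_{j=1}^{H} x_j^2, \qquad
    r = \left(\sigma^2 + \varepsilon\right)^{-1/2}, \qquad
    y_i = w_i \, x_i \, r

    \frac{\partial L}{\partial x_i} = r \, w_i g_i - \frac{r^3}{H} \, x_i \sum_{j=1}^{H} w_j g_j x_j,
    \qquad
    \frac{\partial L}{\partial w_i} = \sum_{\text{tokens}} g_i \, x_i \, r

The second sum (global_dot in the code) is the only cross-shard reduction needed in the backward pass; the weight gradient stays local to each shard.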


@torch.compile
def rms_norm_distributed(x, w, eps: float, group: dist.ProcessGroup, H_global: int):
# Shapes: x [B,S,H_local], w [H_local]
x_dtype = x.dtype
x = x.float()
w = w.float()
# Tokenwise global stats
ss_local = x.square().sum(dim=-1, keepdim=True) # [B,S,1]
ss_global = torch.distributed.nn.functional.all_reduce(ss_local, op=dist.ReduceOp.SUM, group=group)
inv_rms = torch.rsqrt(ss_global / float(H_global) + eps)
y = (x * inv_rms) * w
return y.to(dtype=x_dtype)
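
Both the autograd and the compiled paths rely on a single all_reduce of per-shard squared sums. A minimal single-process sketch (illustrative only; the shapes, n_shards, and the chunk-based emulation of the reduction are assumptions, not part of this PR) checks that sharding the hidden dimension this way reproduces torch.rms_norm:

import torch

# Emulate the tensor-parallel squared-sum reduction by splitting the hidden
# dimension into shards, then compare against torch.rms_norm on the full tensor.
B, S, H_global, n_shards, eps = 2, 3, 8, 2, 1e-5
x = torch.randn(B, S, H_global)
w = torch.randn(H_global)

shards = x.chunk(n_shards, dim=-1)  # each chunk plays the role of one TP rank's slice
ss_global = sum(s.square().sum(dim=-1, keepdim=True) for s in shards)  # stands in for the all_reduce
inv_rms = torch.rsqrt(ss_global / H_global + eps)
y_sharded = torch.cat([s * inv_rms for s in shards], dim=-1) * w

y_ref = torch.rms_norm(x, [H_global], w, eps)
torch.testing.assert_close(y_sharded, y_ref)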
72 changes: 69 additions & 3 deletions fast_llm/layers/ssm/config.py
@@ -38,6 +38,7 @@ class SSMDimNames:
head_dim = "ssm_head_dim"
head_groups = "ssm_head_groups"
group_heads = "ssm_group_heads"
conv1d_dim = "ssm_conv1d_dim"

# Mamba 2
x_proj_dim_2 = "x_proj_dim_2" # d_xb
@@ -48,7 +49,10 @@
# Composite dimensions
composite_heads = "ssm_composite_heads"
composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim"
composite_heads_and_head_dim_nontp = "ssm_composite_heads_and_head_dim_nontp"
composite_heads_and_state_dim = "ssm_composite_heads_and_state_dim"
composite_head_groups_and_state = "ssm_composite_head_groups_and_state"
composite_head_groups_and_head = "ssm_composite_head_groups_and_head"

# Concatenated dimensions
concatenated_convolution = "ssm_concatenated_convolution"
@@ -65,6 +69,7 @@ class SSMBlockType(enum.StrEnum):
mamba2_discrete = "m2d"
mamba2 = "m2"
transformer = "t"
nemotron_h_mamba2 = "nm2"

def get_mixer_class(self):
if self == SSMBlockType.mamba:
@@ -79,6 +84,10 @@ def get_mixer_class(self):
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2

return DiscreteMamba2
elif self == SSMBlockType.nemotron_h_mamba2:
from fast_llm.layers.ssm.mamba2 import NemotronHMamba2

return NemotronHMamba2
else:
raise NotImplementedError(self)

@@ -226,6 +235,21 @@ class SSMConfig(LLMBlockConfig):
valid=check_field(Assert.gt, 0),
)

# Nemotron-H Mamba2 (the actual Mamba2 formulation).
# Here, instead of deriving the head dimension from the state size, we set head_dim explicitly and derive the number of heads from d_inner.
# Note: we do not implement n_groups for Mamba2 because, since we do MiL init, we do not want to share B and C parameters across heads.
# Instead, we mimic the GQA behaviour (x -> v, B -> k, C -> q), where x and B are shared across heads, so this is effectively the same as having n_groups = n_heads.
# n_groups: int = Field(
# default=8,
# desc="Number of groups for Mamba2. Allows sharing B and C parameters accross heads.",
# hint=FieldHint.architecture,
# )
head_dim: int = Field(
default=64,
desc="Head dimension for Nemotron H",
hint=FieldHint.architecture,
)

def _validate(self) -> None:
with self._set_implicit_default():
if self.activation_type is None:
@@ -243,6 +267,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
elif block_type == SSMBlockType.mamba2:
num_heads = div(self.d_inner, self.state_size)
num_head_groups = div(self.d_xb, self.state_size)
elif block_type == SSMBlockType.nemotron_h_mamba2:
# head dim and state size are not the same
num_heads = div(self.d_inner, self.head_dim)
num_head_groups = div(self.d_xb, self.head_dim)
elif block_type == SSMBlockType.mamba2_discrete:
# TODO: Use different variables?
num_heads = self.n_v_heads
@@ -253,6 +281,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size))
if block_type == SSMBlockType.mamba2_discrete:
tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads)))
elif block_type == SSMBlockType.nemotron_h_mamba2:
tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, self.head_dim))
else:
head_dim = state

@@ -261,14 +291,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(
heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))
)
# full d_inner or intermediate_size (e.g. for z gate, also the d_inner size for C in mamba2)
tensor_space.add_tensor_dim(
heads_and_head_dim := CompositeTensorDim(
SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim)
)
)
# d_xb
tensor_space.add_tensor_dim(
head_groups_and_state := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_state, (head_groups, state)
head_groups_and_head := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_head, (head_groups, head_dim)
)
)
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension))
@@ -292,7 +324,41 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim),
(heads_and_head_dim, head_groups_and_head, head_groups_and_head, heads_and_head_dim),
)
)
elif block_type == SSMBlockType.nemotron_h_mamba2:
# full (non-tensor-parallel) inner dim, used for the norm
tensor_space.add_tensor_dim(
TensorDim(
SSMDimNames.composite_heads_and_head_dim_nontp, num_head_groups * group_heads.size * head_dim.size
)
)
# state and head dim are not the same
# C: for each head, size of state
tensor_space.add_tensor_dim(
heads_and_state_dim := CompositeTensorDim(
SSMDimNames.composite_heads_and_state_dim, (head_groups, group_heads, state)
)
)
# B: for each head group, size of state
tensor_space.add_tensor_dim(
head_groups_and_state := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_state, (head_groups, state)
)
)
# the depthwise conv layer is applied to the concatenated xBC, so this dim is the concatenation of the x (d_xb), B and C dims
tensor_space.add_tensor_dim(
conv1d_dim := ConcatenatedTensorDim(
SSMDimNames.conv1d_dim, (heads_and_state_dim, head_groups_and_head, head_groups_and_state)
)
)

# inner projection dimension: also includes the z gate, which has size d_inner (heads_and_head_dim)
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(conv1d_dim, heads_and_head_dim),
)
)
elif block_type == SSMBlockType.mamba2_discrete:
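
To make the nemotron_h_mamba2 dimension bookkeeping above concrete, a small sketch with assumed example values (d_inner, d_xb, state_size and head_dim here are illustrative, not taken from this PR):

# Illustrative only: assumed example values.
d_inner, d_xb, state_size, head_dim = 4096, 1024, 128, 64

num_heads = d_inner // head_dim        # heads (C / z side)
num_head_groups = d_xb // head_dim     # head groups (x / B side, GQA-style sharing)

heads_and_head_dim = num_heads * head_dim             # d_inner: z gate (and the non-TP dim used by the norm)
heads_and_state = num_heads * state_size               # C: one state vector per head
head_groups_and_head = num_head_groups * head_dim      # x: d_xb
head_groups_and_state = num_head_groups * state_size   # B: one state vector per head group

conv1d_dim = heads_and_state + head_groups_and_head + head_groups_and_state  # xBC fed to the depthwise conv
inner_projection = conv1d_dim + heads_and_head_dim                            # xBC + z
print(num_heads, num_head_groups, conv1d_dim, inner_projection)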