4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
cls.base_model_converter_class.export_config(config.base_model),
{
"model_type": cls.get_huggingface_model_type(),
"architecture": cls.architecture,
"architectures": [cls.architecture],
},
)

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
Assert.eq(config["model_type"], cls.get_huggingface_model_type())
Assert.eq(config["architecture"], cls.architecture)
Assert.eq(config["architectures"], [cls.architecture])
return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})

def _create_weight_converters(self) -> list[WeightConverter]:
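For reference, Hugging Face config.json files store the model class under an `architectures` list rather than a scalar key, which is what the export/import change above aligns with. A minimal illustration (the model type and class name are examples, not this repo's converters):

import typing

# Example only: the kind of entries the exported dict is expected to contain.
hf_config: dict[str, typing.Any] = {
    "model_type": "llama",                  # cls.get_huggingface_model_type()
    "architectures": ["LlamaForCausalLM"],  # [cls.architecture] -- a list, per Hugging Face convention
}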
54 changes: 51 additions & 3 deletions fast_llm/functional/linear.py
@@ -65,8 +65,24 @@ def update_linear_gradients(
)
else:
accumulate_gradient(weight, torch.mm(lhs, rhs))

# Bias gradients
if bias is not None and bias.requires_grad:
accumulate_gradient(bias, grad_output.sum(dim=0))
if sparse_map is not None and bias.ndim == 2:
# For sparse maps with 2D bias: bias has shape (num_experts, out_features_per_expert)
# This is the case for manually created MoE biases (e.g., layer_2 in MoE)
# Need to sum gradients per expert
grad_bias = torch.zeros_like(bias)
for expert_idx in range(sparse_map.num_experts):
expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item()
expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item()
# Sum gradients only from unpadded rows
if expert_begin < expert_pad_begin:
grad_bias[expert_idx].copy_(grad_output[expert_begin:expert_pad_begin].sum(dim=0))
accumulate_gradient(bias, grad_bias)
else:
# For 1D bias (including sparse maps where bias already has experts in flattened dim)
accumulate_gradient(bias, grad_output.sum(dim=0))
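The ranges used above come straight from the sparse map: expert i owns sparse rows [expert_ends[i-1], expert_ends[i]) (with 0 in place of expert_ends[-1]), and only the prefix up to expert_pad_begins[i] holds real tokens. A standalone sketch of the same reduction, assuming just those two tensors:

import torch

def per_expert_bias_grad(grad_output: torch.Tensor, num_experts: int,
                         expert_ends: torch.Tensor, expert_pad_begins: torch.Tensor) -> torch.Tensor:
    # grad_output: (num_sparse_rows, out_features_per_expert), rows grouped and padded per expert.
    grad_bias = grad_output.new_zeros(num_experts, grad_output.size(-1))
    for expert_idx in range(num_experts):
        begin = 0 if expert_idx == 0 else int(expert_ends[expert_idx - 1])
        pad_begin = int(expert_pad_begins[expert_idx])
        if begin < pad_begin:  # experts with no unpadded rows contribute nothing
            grad_bias[expert_idx] = grad_output[begin:pad_begin].sum(dim=0)
    return grad_bias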


def linear_forward(
@@ -115,14 +131,30 @@ def output_parallel_linear_forward(

# Matmul
if TritonConfig.TRITON_LINEAR or sparse_map is not None:
assert bias is None
if sparse_map is not None:
assert not transposed_weight
output = output_sparse_matmul(
input1.flatten(0, -2),
maybe_transpose(weight, not transposed_weight),
sparse_map,
).unflatten(0, input_.shape[:-1])
# Add bias if present (for sparse maps, bias has expert dimension)
if bias is not None:
if sparse_map is not None:
# bias shape: (num_experts, out_features_per_expert)
# We need to add the correct expert's bias to each row
# sparse_map tells us which expert each row belongs to
output_flat = output.flatten(0, -2)
for expert_idx in range(sparse_map.num_experts):
expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item()
expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item()
# Add bias only to unpadded rows
if expert_begin < expert_pad_begin:
output_flat[expert_begin:expert_pad_begin] += bias[expert_idx]
output = output_flat.unflatten(0, input_.shape[:-1])
else:
# Regular bias for non-sparse case
output = output + bias
else:
output = torch.nn.functional.linear(input1, maybe_transpose(weight, transposed_weight), bias)
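The per-expert Python loop above can also be written without host-side iteration; a vectorized sketch (assuming expert_ends is ascending and covers every sparse row) recovers each row's expert with searchsorted and masks out the padded rows:

import torch

def add_expert_bias(output_flat: torch.Tensor, bias: torch.Tensor,
                    expert_ends: torch.Tensor, expert_pad_begins: torch.Tensor) -> torch.Tensor:
    rows = torch.arange(output_flat.size(0), device=output_flat.device)
    # A row belongs to the first expert whose cumulative end exceeds it.
    expert_idx = torch.searchsorted(expert_ends.long(), rows, right=True).clamp_(max=bias.size(0) - 1)
    unpadded = rows < expert_pad_begins.long()[expert_idx]
    output_flat[unpadded] += bias[expert_idx[unpadded]]
    return output_flat

For small expert counts the explicit loop is likely just as fast; this is only an alternative, not a required change.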

@@ -179,12 +211,28 @@ def input_parallel_linear_forward(
) -> tuple[torch.Tensor, tuple[typing.Any, ...]]:
# Matmul
if TritonConfig.TRITON_LINEAR or sparse_map is not None:
assert bias is None
if sparse_map is not None:
assert transposed_weight
output = input_inner_sparse_matmul(
input_.flatten(0, -2), maybe_transpose(weight, not transposed_weight), sparse_map
).unflatten(0, input_.shape[:-1])
# Add bias if present (for sparse maps, bias has expert dimension)
if bias is not None:
if sparse_map is not None:
# bias shape: (num_experts, out_features_per_expert)
# We need to add the correct expert's bias to each row
# sparse_map tells us which expert each row belongs to
output_flat = output.flatten(0, -2)
for expert_idx in range(sparse_map.num_experts):
expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item()
expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item()
# Add bias only to unpadded rows
if expert_begin < expert_pad_begin:
output_flat[expert_begin:expert_pad_begin] += bias[expert_idx]
output = output_flat.unflatten(0, input_.shape[:-1])
else:
# Regular bias for non-sparse case
output = output + bias
else:
output = torch.nn.functional.linear(input_, maybe_transpose(weight, transposed_weight), bias)

26 changes: 23 additions & 3 deletions fast_llm/functional/triton/mlp.py
@@ -459,6 +459,8 @@ def mlp_autograd_looped(
sequence_parallel: bool,
training: bool = True,
recompute_level: MLPRecomputeLevel = MLPRecomputeLevel.none,
bias_1: torch.Tensor | None = None,
bias_2: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO: Needed?
scores = scores.to(hidden_states.dtype)
@@ -468,17 +470,30 @@
hidden_states, weight_1_chunked = chunk_weight(hidden_states, weight_1, num_experts)
hidden_states, weight_2_t_chunked = chunk_weight(hidden_states, weight_2, num_experts)

for expert_idx, (weight_1_chunk, weight_2_t_chunk) in enumerate(zip(weight_1_chunked, weight_2_t_chunked)):
# Chunk biases if present
if bias_1 is not None:
_, bias_1_chunked = chunk_weight(hidden_states, bias_1, num_experts)
else:
bias_1_chunked = [None] * num_experts

if bias_2 is not None:
_, bias_2_chunked = chunk_weight(hidden_states, bias_2, num_experts)
else:
bias_2_chunked = [None] * num_experts

for expert_idx, (weight_1_chunk, weight_2_t_chunk, bias_1_chunk, bias_2_chunk) in enumerate(
zip(weight_1_chunked, weight_2_t_chunked, bias_1_chunked, bias_2_chunked)
):
row, column = torch.where(expert_mask[expert_idx])
if column.size(0) > 0:
output[column] += (
mlp_autograd(
hidden_states[column],
None,
weight_1_chunk,
None,
bias_1_chunk,
weight_2_t_chunk,
None,
bias_2_chunk,
gated,
activation_type,
group,
Expand All @@ -490,6 +505,11 @@ def mlp_autograd_looped(
* scores[column, row, None]
)

# Finalize gradient tracking in reverse order
if bias_2 is not None:
output = chunk_weight_post(output, bias_2, bias_2_chunked)
if bias_1 is not None:
output = chunk_weight_post(output, bias_1, bias_1_chunked)
output = chunk_weight_post(output, weight_2, weight_2_t_chunked)
output = chunk_weight_post(output, weight_1, weight_1_chunked)
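The new bias_1/bias_2 arguments default to None, in which case every expert falls back to a bias-free mlp_autograd call; when provided, they are assumed to carry one leading row per expert so that chunk_weight can split them exactly as it splits the weights. A rough shape sketch (dimensions illustrative, gated MLP assumed):

import torch

# Illustrative shapes only; the authoritative layout is defined by the MoE layer that creates these tensors.
num_experts, hidden_size, intermediate_size, gated = 8, 1024, 2816, True
bias_1 = torch.zeros(num_experts, intermediate_size * (2 if gated else 1))  # first linear, per expert
bias_2 = torch.zeros(num_experts, hidden_size)                              # second linear, per expert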

97 changes: 96 additions & 1 deletion fast_llm/functional/triton/sparse_copy.py
@@ -307,7 +307,9 @@ def get_sparse_map(
num_rows_unpadded = num_rows_dense * num_experts_per_token
max_rows = (num_rows_unpadded + num_experts * pad_to_multiple) // pad_to_multiple * pad_to_multiple
dtype = torch.int16 if max_rows < 32768 else torch.int32
if (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton:
# TEMPORARY: Disable Triton kernel due to bug on Triton 3.3+/ARM64
# TODO: Fix sparse_map_kernel to work correctly on newer Triton versions
if False and ((use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton):
expert_ends, expert_pad_begins = top_experts.new_empty((2 * num_experts,), dtype=dtype).chunk(2)
sparse_rows = expert_ends.new_empty(num_rows_dense, num_experts_per_token)
sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)](
@@ -335,3 +337,96 @@
num_experts=num_experts,
num_experts_per_token=num_experts_per_token,
)


@triton_jit()
def add_sparse_bias_kernel(
input_ptr,
bias_ptr,
output_ptr,
expert_ends_ptr,
num_columns: tl_constexpr,
num_experts: tl_constexpr,
block_size: tl_constexpr,
):
"""Add expert-specific bias to sparse tensor."""
sparse_row = tl.program_id(0)
offsets = tl.arange(0, block_size) + block_size * tl.program_id(1)
mask = None if num_columns % block_size == 0 else offsets < num_columns

# Find which expert this sparse row belongs to
# The sparse rows are organized such that rows for expert i are in range [expert_begins[i], expert_ends[i])
expert_idx = 0
for i in range(num_experts):
expert_end = tl.load(expert_ends_ptr + i)
# Count the experts whose range ends at or before this row (avoids `break`, which Triton loops do not support).
expert_idx += tl.where(sparse_row >= expert_end, 1, 0)

# Load input and bias
input_val = tl.load(input_ptr + sparse_row * num_columns + offsets, mask=mask)
bias_val = tl.load(bias_ptr + expert_idx * num_columns + offsets, mask=mask)

# Add bias and store
output_val = input_val + bias_val
tl.store(output_ptr + sparse_row * num_columns + offsets, output_val, mask=mask)


def add_sparse_bias(
input_: torch.Tensor, # shape: (num_sparse_rows, out_features_per_expert)
bias: torch.Tensor, # shape: (num_experts, out_features_per_expert)
sparse_map: SparseMap,
) -> torch.Tensor:
"""Add expert-specific biases to sparse tensor based on expert assignment."""
num_sparse_rows, hidden_size = input_.shape
num_experts, bias_hidden_size = bias.shape
assert hidden_size == bias_hidden_size, f"Hidden size mismatch: {hidden_size} vs {bias_hidden_size}"
assert num_experts == sparse_map.num_experts

# Use PyTorch implementation for now (can optimize with Triton later if needed)
output = input_.clone()

# For each expert, add its bias to the rows it processed
for expert_idx in range(num_experts):
expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item()
expert_end = sparse_map.expert_ends[expert_idx].item()
expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item()

# Add bias only to unpadded rows
if expert_begin < expert_pad_begin:
output[expert_begin:expert_pad_begin] += bias[expert_idx]

return output


def add_sparse_bias_forward(
input_: torch.Tensor, bias: torch.Tensor, sparse_map: SparseMap
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, SparseMap]]:
return add_sparse_bias(input_, bias, sparse_map), (input_, bias, sparse_map)


def add_sparse_bias_backward(
grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, SparseMap]
) -> tuple[torch.Tensor, torch.Tensor]:
input_, bias, sparse_map = context

# Gradient w.r.t. input is just grad_output (bias is added elementwise)
grad_input = grad_output

# Gradient w.r.t. bias: sum gradients for each expert's rows
grad_bias = torch.zeros_like(bias)
num_experts = sparse_map.num_experts

for expert_idx in range(num_experts):
expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item()
expert_end = sparse_map.expert_ends[expert_idx].item()
expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item()

# Sum gradients only from unpadded rows
if expert_begin < expert_pad_begin:
grad_bias[expert_idx] = grad_output[expert_begin:expert_pad_begin].sum(dim=0)

return grad_input, grad_bias


add_sparse_bias_autograd = wrap_forward_backward(add_sparse_bias_forward, add_sparse_bias_backward)
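A small self-contained check of the layout add_sparse_bias assumes (a toy stand-in for SparseMap with two experts, the first padded by one row):

import torch
from types import SimpleNamespace

# Toy stand-in: rows 0-2 belong to expert 0 (row 2 is padding), rows 3-5 to expert 1 (no padding).
toy_map = SimpleNamespace(
    num_experts=2,
    expert_ends=torch.tensor([3, 6]),
    expert_pad_begins=torch.tensor([2, 6]),
)
x = torch.zeros(6, 4)
bias = torch.arange(2, dtype=torch.float32)[:, None].expand(2, 4).contiguous()  # expert i's bias is all i
out = add_sparse_bias(x, bias, toy_map)
assert out[:3].eq(0).all()   # expert 0's bias is zero; its padded row is left untouched
assert out[3:].eq(1).all()   # expert 1's bias (ones) is added to all of its unpadded rows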
16 changes: 16 additions & 0 deletions fast_llm/layers/attention/attention.py
@@ -143,6 +143,22 @@ def __init__(
# Rotary embeddings.
self._rotary = self._config.rotary.get_layer(head_size_dim)

# Attention sinks for streaming attention (optional)
# Sinks are learnable embeddings, one per head
# TODO: Implement sinks usage in forward pass
sinks_dim = TensorDim("sinks", self._config.heads)
sinks = self._config.sinks.get_parameter(
(sinks_dim,),
default_initialization=init_normal_(std=self._hidden_size**-0.5),
lr_scale=self._lr_scale,
default_enabled=False,
peft=None,
)
if sinks is not None:
# Allow the parameter to have no gradient, since sinks are not yet used in the forward pass
sinks.allow_no_grad = True
self.sinks = sinks

# Output.
self.dense = self._config.dense_layer.get_layer(
dense_dim,
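For context on the TODO above: attention sinks in the streaming-attention sense are usually realized as one extra learnable logit per head that takes part in the softmax normalization but contributes no value. A rough sketch of how a forward pass might later consume self.sinks (shapes and function name are assumptions, not the repo's API):

import torch

def attend_with_sinks(query, key, value, sinks):
    # query/key/value: (batch, heads, seq, head_size); sinks: (heads,) -- one learnable logit per head.
    scores = torch.einsum("bhqd,bhkd->bhqk", query, key) / query.size(-1) ** 0.5
    sink_logit = sinks.view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)  # one "virtual" key per head
    probs = torch.softmax(torch.cat([scores, sink_logit], dim=-1), dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs[..., :-1], value)      # the sink's probability mass is dropped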
5 changes: 5 additions & 0 deletions fast_llm/layers/attention/config.py
@@ -4,6 +4,7 @@

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.parameter import OptionalParameterConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.layers.attention.rotary.config import RotaryConfig
Expand Down Expand Up @@ -99,6 +100,10 @@ class AttentionConfig(MixerConfig):
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
sinks: OptionalParameterConfig = Field(
desc="Configuration for attention sinks parameter. Sinks are learnable embeddings (one per head) prepended to keys/values for streaming attention.",
hint=FieldHint.architecture,
)
softmax_scale_power: float = Field(
default=0.5,
desc="The scaling power to apply to head_size in the attention calculation. "