Skip to content

Commit 4f5a294

Browse files
authored
feat(archon): add moe_router_dtype config for FP32 router gate GEMM (#1009)
Add configurable FP32 precision for MoE router gate GEMM to improve numerical stability with large expert counts, using a Megatron-Core-style custom torch.autograd.Function. Key changes: - Add moe_router_dtype field to ArchonEngineConfig (default "fp32") - Add router_dtype field to MoEArgs dataclass - Implement RouterGatingLinearFunction with FP32 forward/backward - Thread config from ArchonEngineConfig through to TokenChoiceTopKRouter - None means no override (use model dtype), "fp32" runs gate GEMM in float32 - Consolidate test_moe_args.py and test_router_fp32.py into test_moe_common.py
1 parent eb494de commit 4f5a294

10 files changed

Lines changed: 478 additions & 120 deletions

File tree

areal/api/cli_args.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,17 @@ class ArchonEngineConfig:
542542
},
543543
)
544544

545+
# MoE
546+
moe_router_dtype: str | None = field(
547+
default="fp32",
548+
metadata={
549+
"help": "Data type for MoE router gate GEMM computation. "
550+
"'fp32' runs gate linear in float32 for numerical stability. "
551+
"None uses model dtype (no override).",
552+
"choices": ["fp32", None],
553+
},
554+
)
555+
545556
def __post_init__(self):
546557
if self.pp_layers_per_stage is not None and self.pp_layers_per_stage < 1:
547558
raise ValueError(
@@ -563,6 +574,12 @@ def __post_init__(self):
563574
f"reshard_after_forward_policy must be one of {valid_reshard_policies}, "
564575
f"got '{self.reshard_after_forward_policy}'"
565576
)
577+
valid_router_dtypes = ("fp32", None)
578+
if self.moe_router_dtype not in valid_router_dtypes:
579+
raise ValueError(
580+
f"moe_router_dtype must be one of {valid_router_dtypes}, "
581+
f"got '{self.moe_router_dtype}'"
582+
)
566583

567584

568585
# These configurations are used by Megatron Bridge to build Megatron models.

areal/experimental/engine/archon_engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,10 +976,15 @@ def _create_model_structure(self) -> nn.Module:
976976
)
977977
attn_type = "varlen"
978978

979+
# Map moe_router_dtype string config to torch.dtype; None means no override
980+
router_dtype = (
981+
torch.float32 if self.config.archon.moe_router_dtype == "fp32" else None
982+
)
979983
model_args = self.spec.model_args_class.from_hf_config(
980984
self.model_config,
981985
is_critic=self.config.is_critic,
982986
attn_type=attn_type,
987+
router_dtype=router_dtype,
983988
)
984989
return self.spec.model_class(model_args)
985990

areal/experimental/models/archon/moe/args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dataclasses import dataclass
66
from typing import TYPE_CHECKING, Literal
77

8+
import torch
9+
810
if TYPE_CHECKING:
911
from transformers import PretrainedConfig
1012

@@ -25,6 +27,8 @@ class MoEArgs:
2527
route_norm: Whether to normalize routing scores.
2628
route_scale: Scale factor for routing scores.
2729
score_before_experts: Whether to apply scores before or after expert computation.
30+
router_dtype: Data type for router gate GEMM computation.
31+
If None, the model's default dtype is used (no override).
2832
2933
num_expert_groups: Number of expert groups for node-limited routing.
3034
If None, standard top-k routing is used.
@@ -51,6 +55,7 @@ class MoEArgs:
5155
route_norm: bool = False
5256
route_scale: float = 1.0
5357
score_before_experts: bool = False
58+
router_dtype: torch.dtype | None = None
5459

5560
# Node-limited routing (optional)
5661
num_expert_groups: int | None = None

areal/experimental/models/archon/moe/moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
7676
score_func=moe_args.score_func,
7777
route_norm=moe_args.route_norm,
7878
route_scale=moe_args.route_scale,
79+
router_dtype=moe_args.router_dtype,
7980
num_expert_groups=moe_args.num_expert_groups,
8081
num_limited_groups=moe_args.num_limited_groups,
8182
_debug_force_load_balance=moe_args._debug_force_load_balance,

areal/experimental/models/archon/moe/router.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,75 @@
22

33
from __future__ import annotations
44

5-
from typing import Literal
5+
from typing import Any, Literal, cast
66

77
import torch
88
import torch.nn.functional as F
99
from torch import nn
1010

1111

12+
class RouterGatingLinearFunction(torch.autograd.Function):
13+
"""Custom autograd function for MoE router gate GEMM in higher precision.
14+
15+
Performs the gate linear layer (input @ weight.T) in the specified dtype
16+
while saving tensors in the original dtype for memory efficiency.
17+
18+
This is adapted from Megatron-Core's RouterGatingLinearFunction.
19+
"""
20+
21+
@staticmethod
22+
@torch.amp.custom_fwd(device_type="cuda")
23+
def forward(
24+
ctx: torch.autograd.function.FunctionCtx,
25+
input: torch.Tensor,
26+
weight: torch.Tensor,
27+
router_dtype: torch.dtype,
28+
) -> torch.Tensor:
29+
"""Forward pass: compute input @ weight.T in router_dtype.
30+
31+
Saves input and weight in their original dtype (BF16) for memory efficiency.
32+
"""
33+
ctx.save_for_backward(input, weight)
34+
cast(Any, ctx).router_dtype = router_dtype
35+
return torch.mm(input.to(router_dtype), weight.to(router_dtype).t())
36+
37+
@staticmethod
38+
@torch.amp.custom_bwd(device_type="cuda")
39+
def backward(
40+
ctx: torch.autograd.function.FunctionCtx,
41+
*grad_outputs: torch.Tensor,
42+
) -> tuple[torch.Tensor, torch.Tensor, None]:
43+
"""Backward pass: compute gradients in router_dtype, return in original dtype."""
44+
grad_output = grad_outputs[0]
45+
input, weight = cast(Any, ctx).saved_tensors
46+
router_dtype = cast(Any, ctx).router_dtype
47+
grad_output_fp = grad_output.to(router_dtype)
48+
grad_input = grad_output_fp.mm(weight.to(router_dtype)).to(input.dtype)
49+
grad_weight = grad_output_fp.t().mm(input.to(router_dtype)).to(weight.dtype)
50+
return grad_input, grad_weight, None
51+
52+
53+
def router_gating_linear(
54+
input: torch.Tensor, weight: torch.Tensor, router_dtype: torch.dtype | None
55+
) -> torch.Tensor:
56+
"""Apply router gate linear with optional dtype casting for numerical stability.
57+
58+
Args:
59+
input: Input tensor (num_tokens, dim).
60+
weight: Gate weight tensor (num_experts, dim).
61+
router_dtype: Dtype to use for GEMM. If None, uses standard F.linear.
62+
63+
Returns:
64+
Output tensor (num_tokens, num_experts).
65+
"""
66+
if router_dtype is not None:
67+
return cast(
68+
torch.Tensor,
69+
RouterGatingLinearFunction.apply(input, weight, router_dtype),
70+
)
71+
return F.linear(input, weight)
72+
73+
1274
class TokenChoiceTopKRouter(nn.Module):
1375
"""Token-choice routing with top-k expert selection.
1476
@@ -23,6 +85,7 @@ class TokenChoiceTopKRouter(nn.Module):
2385
score_func: Scoring function, either "softmax" or "sigmoid".
2486
route_norm: Whether to normalize routing scores after top-k selection.
2587
route_scale: Scale factor applied to routing scores.
88+
router_dtype: Data type for gate GEMM. If None, uses model dtype (no override).
2689
num_expert_groups: Number of expert groups for node-limited routing.
2790
If None, standard top-k routing is used.
2891
num_limited_groups: Number of groups to select in node-limited routing.
@@ -41,6 +104,7 @@ def __init__(
41104
score_func: Literal["softmax", "sigmoid"] = "sigmoid",
42105
route_norm: bool = False,
43106
route_scale: float = 1.0,
107+
router_dtype: torch.dtype | None = None,
44108
num_expert_groups: int | None = None,
45109
num_limited_groups: int | None = None,
46110
_debug_force_load_balance: bool = False,
@@ -52,6 +116,7 @@ def __init__(
52116
self.score_func = score_func
53117
self.route_norm = route_norm
54118
self.route_scale = route_scale
119+
self.router_dtype = router_dtype
55120
self.num_expert_groups = num_expert_groups
56121
self.num_limited_groups = num_limited_groups
57122
self._debug_force_load_balance = _debug_force_load_balance
@@ -147,7 +212,7 @@ def forward(
147212
- num_tokens_per_expert: Token count per expert, shape (num_experts,).
148213
"""
149214
# Compute gate scores: (num_tokens, num_experts)
150-
scores = self.gate(x)
215+
scores = router_gating_linear(x, self.gate.weight, self.router_dtype)
151216

152217
# Apply scoring function in float32 to avoid loss explosion
153218
if self.score_func == "sigmoid":

areal/experimental/models/archon/qwen3/model/args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def from_hf_config(
7070
# Override with additional fields from HF config
7171
if hasattr(hf_config, "num_shared_experts"):
7272
moe_args.num_shared_experts = hf_config.num_shared_experts
73+
router_dtype = kwargs.get("router_dtype", None)
74+
if router_dtype is not None:
75+
moe_args.router_dtype = router_dtype
7376

7477
# Get decoder_sparse_step (default to 1 = all MoE layers)
7578
decoder_sparse_step = getattr(hf_config, "decoder_sparse_step", 1)

docs/en/cli_reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ Configuration for Archon Engine training backend.
808808
| `pp_last_stage_less_layers` | integer | `1` | Number of layers to reduce in the last pipeline stage. Accounts for output layer overhead. |
809809
| `reshard_after_forward_policy` | string | `"default"` | FSDP reshard policy after forward pass. 'default': reshard when pipeline parallelism is off; keep unsharded when on to avoid repeated all-gather per microbatch. 'always': always reshard after forward (saves memory). 'never': never reshard after forward. **Choices:** `default`, `always`, `never` |
810810
| `use_deterministic_algorithms` | boolean | `False` | Enable deterministic algorithms for training reproducibility. Sets torch.use_deterministic_algorithms(True, warn_only=True), CUBLAS_WORKSPACE_CONFIG, NCCL_ALGO, and TORCH_COMPILE_DETERMINISTIC. May reduce performance. |
811+
| `moe_router_dtype` | string \| None | `"fp32"` | Data type for MoE router gate GEMM computation. 'fp32' runs gate linear in float32 for numerical stability. None uses model dtype (no override). **Choices:** `fp32`, `None` |
811812

812813
(section-distributed-data-parallel)=
813814

docs/zh/cli_reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ Configuration for Archon Engine training backend.
806806
| `pp_last_stage_less_layers` | integer | `1` | Number of layers to reduce in the last pipeline stage. Accounts for output layer overhead. |
807807
| `reshard_after_forward_policy` | string | `"default"` | FSDP reshard policy after forward pass. 'default': reshard when pipeline parallelism is off; keep unsharded when on to avoid repeated all-gather per microbatch. 'always': always reshard after forward (saves memory). 'never': never reshard after forward. **Choices:** `default`, `always`, `never` |
808808
| `use_deterministic_algorithms` | boolean | `False` | Enable deterministic algorithms for training reproducibility. Sets torch.use_deterministic_algorithms(True, warn_only=True), CUBLAS_WORKSPACE_CONFIG, NCCL_ALGO, and TORCH_COMPILE_DETERMINISTIC. May reduce performance. |
809+
| `moe_router_dtype` | string \| None | `"fp32"` | Data type for MoE router gate GEMM computation. 'fp32' runs gate linear in float32 for numerical stability. None uses model dtype (no override). **Choices:** `fp32`, `None` |
809810

810811
(section-distributed-data-parallel)=
811812

tests/experimental/archon/test_moe_args.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

0 commit comments

Comments (0)