feat(archon): add moe_router_dtype config for FP32 router gate GEMM (#1009)
Add configurable FP32 precision for MoE router gate GEMM to improve
numerical stability with large expert counts, using a Megatron-Core-style
custom torch.autograd.Function.
Key changes:
- Add moe_router_dtype field to ArchonEngineConfig (default "fp32")
- Add router_dtype field to MoEArgs dataclass
- Implement RouterGatingLinearFunction with FP32 forward/backward
- Thread config from ArchonEngineConfig through to TokenChoiceTopKRouter
- None means no override (use the model dtype); "fp32" runs the gate GEMM in float32
- Consolidate test_moe_args.py and test_router_fp32.py into test_moe_common.py
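The RouterGatingLinearFunction mentioned above can be sketched as a custom `torch.autograd.Function` in the Megatron-Core style: upcast the activations and gate weight to float32, run the GEMM and return FP32 logits, and cast gradients back to the parameter dtypes in backward. This is a minimal illustrative sketch (assuming tokens are already flattened to 2-D), not the actual Archon implementation:

```python
import torch

class RouterGatingLinearFunction(torch.autograd.Function):
    """Run the MoE router gate GEMM in float32 for numerical stability.

    Sketch of the Megatron-Core-style approach: inputs/weight may be
    bf16/fp16; the gate linear and its logits are computed in FP32,
    and gradients are cast back to the original dtypes.
    Assumes 2-D inputs of shape (num_tokens, hidden).
    """

    @staticmethod
    def forward(ctx, inputs, weight):
        ctx.input_dtype = inputs.dtype
        ctx.weight_dtype = weight.dtype
        inputs_fp32 = inputs.float()
        weight_fp32 = weight.float()
        ctx.save_for_backward(inputs_fp32, weight_fp32)
        # Gate GEMM in FP32: logits = x @ W^T, returned as float32
        return torch.nn.functional.linear(inputs_fp32, weight_fp32)

    @staticmethod
    def backward(ctx, grad_output):
        inputs_fp32, weight_fp32 = ctx.saved_tensors
        # grad_x = dL/dy @ W;  grad_W = (dL/dy)^T @ x  (2-D case)
        grad_input = grad_output @ weight_fp32
        grad_weight = grad_output.t() @ inputs_fp32
        return grad_input.to(ctx.input_dtype), grad_weight.to(ctx.weight_dtype)
```

Returning FP32 logits (rather than downcasting them) matters because the subsequent softmax/top-k over many experts is where low-precision rounding hurts most.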
docs/en/cli_reference.md (+1 line)
@@ -808,6 +808,7 @@ Configuration for Archon Engine training backend.
|`pp_last_stage_less_layers`| integer |`1`| Number of layers to reduce in the last pipeline stage. Accounts for output layer overhead. |
|`reshard_after_forward_policy`| string |`"default"`| FSDP reshard policy after forward pass. 'default': reshard when pipeline parallelism is off; keep unsharded when on to avoid repeated all-gather per microbatch. 'always': always reshard after forward (saves memory). 'never': never reshard after forward. **Choices:**`default`, `always`, `never`|
|`use_deterministic_algorithms`| boolean |`False`| Enable deterministic algorithms for training reproducibility. Sets torch.use_deterministic_algorithms(True, warn_only=True), CUBLAS_WORKSPACE_CONFIG, NCCL_ALGO, and TORCH_COMPILE_DETERMINISTIC. May reduce performance. |
+ |`moe_router_dtype`| string \| None |`"fp32"`| Data type for MoE router gate GEMM computation. 'fp32' runs gate linear in float32 for numerical stability. None uses model dtype (no override). **Choices:**`fp32`, `None`|
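The None-vs-"fp32" semantics documented in the row above can be expressed as a small dtype resolver. The helper name below is hypothetical, used only to illustrate the override rule:

```python
import torch

def resolve_router_dtype(moe_router_dtype, model_dtype):
    """Illustrative helper (not from the PR): map the config value
    to the dtype used for the router gate GEMM."""
    if moe_router_dtype is None:
        # No override: router gate runs in the model's compute dtype.
        return model_dtype
    if moe_router_dtype == "fp32":
        # Override: run the gate linear in float32 for stability.
        return torch.float32
    raise ValueError(f"unsupported moe_router_dtype: {moe_router_dtype!r}")
```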
docs/zh/cli_reference.md (+1 line)
@@ -806,6 +806,7 @@ Configuration for Archon Engine training backend.
|`pp_last_stage_less_layers`| integer |`1`| Number of layers to reduce in the last pipeline stage. Accounts for output layer overhead. |
|`reshard_after_forward_policy`| string |`"default"`| FSDP reshard policy after forward pass. 'default': reshard when pipeline parallelism is off; keep unsharded when on to avoid repeated all-gather per microbatch. 'always': always reshard after forward (saves memory). 'never': never reshard after forward. **Choices:**`default`, `always`, `never`|
|`use_deterministic_algorithms`| boolean |`False`| Enable deterministic algorithms for training reproducibility. Sets torch.use_deterministic_algorithms(True, warn_only=True), CUBLAS_WORKSPACE_CONFIG, NCCL_ALGO, and TORCH_COMPILE_DETERMINISTIC. May reduce performance. |
+ |`moe_router_dtype`| string \| None |`"fp32"`| Data type for MoE router gate GEMM computation. 'fp32' runs gate linear in float32 for numerical stability. None uses model dtype (no override). **Choices:**`fp32`, `None`|