Skip to content

Commit 4ebc3bf

Browse files
jenchen13 and claude committed
[OMNIML-5003] Restrict non-gated detection to single up_proj (review)
Address review feedback:

- _fused_experts_wrapper_class now claims _QuantNonGatedFusedExperts only for a 3-D up_proj with no gate_proj and no gate_up_proj. A split-gated container (separate 3-D gate_proj/up_proj/down_proj, three F.linear calls per expert) falls through to None/unsupported instead of being mis-wrapped, since the two-call toggle and up_proj-storage index recovery assume exactly two calls.
- Add test_split_gated_layout_not_claimed_as_nongated and test_get_quant_config_resolves_nongated_experts (down_proj anchors format / has_quantizers detection, so the produced quant config is correct).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 7f22c90 commit 4ebc3bf

2 files changed

Lines changed: 61 additions & 4 deletions

File tree

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1502,8 +1502,8 @@ def _fused_experts_wrapper_class(module):
15021502
``MixtralExperts``, ``Qwen2MoeExperts``, ``Qwen3MoeExperts``,
15031503
``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``, ``JambaExperts``,
15041504
``OlmoeExperts``, etc.
1505-
* non-gated (``_QuantNonGatedFusedExperts``): a 3-D ``up_proj`` and no
1506-
``gate_up_proj``. Matches NemotronH ``NemotronHExperts``.
1505+
* non-gated (``_QuantNonGatedFusedExperts``): a 3-D ``up_proj`` with no
1506+
``gate_proj`` and no ``gate_up_proj``. Matches NemotronH ``NemotronHExperts``.
15071507
15081508
Returns ``None`` for non-standard layouts (DBRX, GptOss, GraniteMoE,
15091509
Llama4TextExperts) which have their own explicit registrations.
@@ -1518,7 +1518,14 @@ def _fused_experts_wrapper_class(module):
15181518
return _QuantFusedExperts
15191519
up = getattr(module, "up_proj", None)
15201520
if isinstance(up, (nn.Parameter, Tensor)) and up.dim() == 3:
1521-
return _QuantNonGatedFusedExperts
1521+
# Strictly single up_proj/down_proj only. A split-gated container with a
1522+
# separate gate projection (3-D gate_proj or gate_up_proj) makes three
1523+
# F.linear calls per expert, which would break _QuantNonGatedFusedExperts'
1524+
# two-call toggle and its up_proj-storage expert-index recovery. Such a
1525+
# layout is unsupported here (falls through to None) rather than silently
1526+
# mis-quantizing the wrong projection.
1527+
if getattr(module, "gate_proj", None) is None and gate_up is None:
1528+
return _QuantNonGatedFusedExperts
15221529
return None
15231530

15241531

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 51 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424

2525
import modelopt.torch.quantization as mtq
2626
from modelopt.torch.export.moe_utils import _export_fused_experts
27-
from modelopt.torch.export.quant_utils import get_quant_config
27+
from modelopt.torch.export.quant_utils import get_quant_config, get_quantization_format
2828
from modelopt.torch.quantization.conversion import _normalize_fused_experts_quantizer_name
2929
from modelopt.torch.quantization.model_calib import local_hessian_calibrate
3030
from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer
@@ -1191,3 +1191,53 @@ def test_enumeration_yields_up_and_down_proj(self):
11911191
assert set(weight_attr_names(converted)) == {"up_proj", "down_proj"}
11921192
finally:
11931193
self._cleanup_registry(expert_type)
1194+
1195+
def test_split_gated_layout_not_claimed_as_nongated(self):
1196+
"""A fused container with a separate 3-D gate_proj (split-gated: three
1197+
F.linear calls per expert) must NOT be claimed by the non-gated wrapper,
1198+
whose two-call toggle and up_proj-storage index recovery assume exactly
1199+
two projections. It is left unsupported (None) rather than mis-quantized."""
1200+
1201+
class _SplitGatedExperts(nn.Module):
1202+
def __init__(self):
1203+
super().__init__()
1204+
self.num_experts = NUM_EXPERTS
1205+
self.gate_proj = nn.Parameter(
1206+
torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
1207+
)
1208+
self.up_proj = nn.Parameter(
1209+
torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
1210+
)
1211+
self.down_proj = nn.Parameter(
1212+
torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02
1213+
)
1214+
self.act_fn = nn.SiLU()
1215+
1216+
module = _SplitGatedExperts()
1217+
assert _fused_experts_wrapper_class(module) is None
1218+
assert _is_fused_experts_module(module) is False
1219+
1220+
def test_get_quant_config_resolves_nongated_experts(self):
1221+
"""get_quant_config must detect the non-gated experts as quantized. The
1222+
up_proj weight name does not resolve to its quantizers (they live on the
1223+
gate_up_proj sentinel list), but down_proj anchors both has_quantizers
1224+
(down_proj_input_quantizer) and format detection (down_proj_weight_quantizers),
1225+
so the produced config is correct."""
1226+
model = _TinyNonGatedMoEModel()
1227+
expert_type = type(model.moe.experts)
1228+
self._cleanup_registry(expert_type)
1229+
1230+
def forward_loop(m):
1231+
torch.manual_seed(0)
1232+
for _ in range(2):
1233+
m(torch.randn(1, 4, HIDDEN_DIM))
1234+
1235+
try:
1236+
mtq.quantize(model, self._nongated_fp8_cfg(), forward_loop=forward_loop)
1237+
# Format resolves (via down_proj) instead of QUANTIZATION_NONE (None).
1238+
assert get_quantization_format(model.moe.experts) is not None
1239+
# The non-gated experts are reflected in the produced quant config.
1240+
quant = get_quant_config(model)["quantization"]
1241+
assert quant.get("quant_algo") is not None
1242+
finally:
1243+
self._cleanup_registry(expert_type)

0 commit comments

Comments
 (0)