Skip to content

Commit b5f181e

Browse files
juhi10071998claude
andcommitted
refactor: relocate tied-weight helpers per Edward's review
Move three private helpers out of unified_export_hf.py into the more specific util modules suggested in the round-2 review: - sync_tied_input_amax → modelopt/torch/export/quant_utils.py Quantizer-amax merge helper, fits the file's "Utils for quantization including scaling factors adjustments" charter (sibling to adjust_attn_amax_values, get_*_scaling_factor, etc.). Uses the file's existing `from warnings import warn` rather than a new warnings import. - _collect_canonical_tied_patterns → modelopt/torch/export/model_utils.py - _reorder_canonical_first → modelopt/torch/export/model_utils.py Both are model-introspection helpers that walk the model graph and read HF's _tied_weights_keys declaration, which fits the file's "model type detection and classification" charter (sibling to get_model_type, is_multimodal_model). Added `import re` to support the regex compilation in _collect_canonical_tied_patterns. unified_export_hf.py now imports the three helpers from their new homes and uses them unchanged. No behavioral change. tests/unit/torch/export/test_unified_export_hf.py updates its imports to pull the three helpers from the new locations. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Juhi Mittal <juhim@nvidia.com>
1 parent afd0831 commit b5f181e

4 files changed

Lines changed: 164 additions & 161 deletions

File tree

modelopt/torch/export/model_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
"""Utility functions for model type detection and classification."""
1616

17+
import re
18+
1719
import torch.nn as nn
1820

1921
MODEL_NAME_TO_TYPE = {
@@ -160,3 +162,84 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None:
160162

161163
# Pattern 4: No language_model found
162164
return None
165+
166+
167+
def _collect_canonical_tied_patterns(
168+
model: nn.Module,
169+
) -> tuple[list[re.Pattern], list[str]]:
170+
"""Walk the model and collect canonical-side tied-weight matchers.
171+
172+
Patterns are submodule-prefixed regexes from each module's
173+
``_tied_weights_keys`` dict-style declaration (the prefix matters
174+
for nested models where the dict lives on an inner submodule).
175+
Side substrings are dot-separated tokens that appear only on the
176+
canonical side of those declarations — needed because modelopt's
177+
per-expert unpacking creates post-export keys (e.g.
178+
``…experts.Y.gate_proj.input_scale``) that HF's regexes never knew
179+
about. List-style (legacy) declarations are skipped.
180+
"""
181+
patterns: list[re.Pattern] = []
182+
alias_token_set: set[str] = set()
183+
canonical_token_set: set[str] = set()
184+
185+
def _tokens(s: str) -> set[str]:
186+
"""Identifiers in a regex string, with regex specials as separators."""
187+
return {tok for tok in re.split(r"[^A-Za-z0-9_]+", s) if tok}
188+
189+
for name, submodule in model.named_modules():
190+
tied = getattr(submodule, "_tied_weights_keys", None)
191+
if not isinstance(tied, dict) or not tied:
192+
continue
193+
prefix = f"{name}." if name else ""
194+
for alias_pat, canonical_pat in tied.items():
195+
patterns.append(re.compile(prefix + canonical_pat))
196+
alias_token_set.update(_tokens(prefix + alias_pat))
197+
canonical_token_set.update(_tokens(prefix + canonical_pat))
198+
199+
# Tokens unique to the canonical side become substring matchers.
200+
side_substrings = sorted(canonical_token_set - alias_token_set)
201+
return patterns, side_substrings
202+
203+
204+
def _reorder_canonical_first(state_dict: dict, model: nn.Module) -> dict:
205+
r"""Reorder ``state_dict`` so canonical-side tied keys iterate first.
206+
207+
Lets the downstream first-wins data_ptr dedup keep canonical names.
208+
Uses both regex patterns and substring matchers from
209+
:func:`_collect_canonical_tied_patterns`. Gated on the model class
210+
name to scope the reorder to DiffusionGemma; other tied
211+
encoder-decoder models that ship dict-style ``_tied_weights_keys``
212+
can be added to the allowlist here. Mirrors the ``model_type``
213+
dispatch used for the Whisper and Nemotron-VL branches elsewhere
214+
in ``unified_export_hf.py``.
215+
"""
216+
model_type = type(model).__name__.lower()
217+
if "diffusiongemma" not in model_type and "diffusion_gemma" not in model_type:
218+
return state_dict
219+
220+
canonical_patterns, side_substrings = _collect_canonical_tied_patterns(model)
221+
if not canonical_patterns and not side_substrings:
222+
return state_dict
223+
224+
def _has_side_substring(key: str) -> bool:
225+
# Require the token to appear as a proper dot-separated path
226+
# component, not just as a substring of an unrelated identifier.
227+
for tok in side_substrings:
228+
if (
229+
f".{tok}." in key
230+
or key.startswith(f"{tok}.")
231+
or key.endswith(f".{tok}")
232+
or key == tok
233+
):
234+
return True
235+
return False
236+
237+
head: dict = {}
238+
tail: dict = {}
239+
for k, v in state_dict.items():
240+
if any(p.search(k) for p in canonical_patterns) or _has_side_substring(k):
241+
head[k] = v
242+
else:
243+
tail[k] = v
244+
head.update(tail)
245+
return head

modelopt/torch/export/quant_utils.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,3 +1479,79 @@ def has_quantized_modules(model: nn.Module) -> bool:
14791479
get_quantization_format(sub_module) != QUANTIZATION_NONE
14801480
for _, sub_module in model.named_modules()
14811481
)
1482+
1483+
1484+
def sync_tied_input_amax(model: nn.Module) -> int:
1485+
"""Max-merge input_quantizer amaxes across modules sharing a weight ``data_ptr``.
1486+
1487+
Mutates ``model`` in place: overwrites the ``.amax`` buffer on every
1488+
affected ``input_quantizer`` with the per-group maximum. Intended to
1489+
run as part of an export pipeline that already replaces weights with
1490+
packed bytes downstream — i.e. the model is not expected to be reused
1491+
after this helper runs.
1492+
1493+
Closes the loop on ``input_scale`` for HF-tied modules whose forward
1494+
paths see different activation distributions (encoder vs decoder in
1495+
YOCO-style models). Must run BEFORE per-module export so the merged
1496+
amax flows into ``input_scale`` derivation. Handles both dense
1497+
Linears (keyed by ``weight.data_ptr()``) and fused MoE (keyed by
1498+
``(gate_up_proj, down_proj)`` data_ptr tuple). Returns the number of
1499+
tied groups merged.
1500+
"""
1501+
from collections import defaultdict
1502+
1503+
by_dp: dict = defaultdict(list)
1504+
for _, m in model.named_modules():
1505+
# Fused MoE: 3-D source tensors with shared input quantizers
1506+
if (
1507+
hasattr(m, "gate_up_proj_input_quantizer")
1508+
and hasattr(m, "gate_up_proj")
1509+
and hasattr(m, "down_proj")
1510+
and m.gate_up_proj.dim() == 3
1511+
):
1512+
key = ("moe", m.gate_up_proj.data_ptr(), m.down_proj.data_ptr())
1513+
by_dp[key].append(m)
1514+
# Dense quantized Linear with an input_quantizer
1515+
elif (
1516+
hasattr(m, "input_quantizer")
1517+
and hasattr(m, "weight")
1518+
and isinstance(m.weight, torch.nn.Parameter)
1519+
):
1520+
by_dp[("dense", m.weight.data_ptr())].append(m)
1521+
1522+
def _merge(quantizers: list) -> bool:
1523+
"""Max-merge amaxes across the quantizer list. Returns True on merge."""
1524+
valid = [
1525+
q
1526+
for q in quantizers
1527+
if q is not None
1528+
and getattr(q, "is_enabled", False)
1529+
and getattr(q, "_amax", None) is not None
1530+
and not q._amax.is_meta
1531+
]
1532+
if len(valid) < 2:
1533+
return False
1534+
# Require scalar (per-tensor) amax — matches preprocess_linear_fusion.
1535+
if any(q._amax.numel() != 1 for q in valid):
1536+
warn(
1537+
"sync_tied_input_amax: non-scalar input_quantizer amax encountered "
1538+
"in a tied group; skipping. Only per-tensor input quantizers are "
1539+
"supported for tied-modules merging."
1540+
)
1541+
return False
1542+
merged = torch.max(torch.stack([q.amax for q in valid]))
1543+
for q in valid:
1544+
q.amax = merged.clone()
1545+
return True
1546+
1547+
synced = 0
1548+
for key, modules in by_dp.items():
1549+
if len(modules) < 2:
1550+
continue
1551+
if key[0] == "moe":
1552+
for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"):
1553+
if _merge([getattr(m, q_name, None) for m in modules]):
1554+
synced += 1
1555+
elif _merge([m.input_quantizer for m in modules]):
1556+
synced += 1
1557+
return synced

modelopt/torch/export/unified_export_hf.py

Lines changed: 2 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
QUANTIZATION_W4A8_NVFP4_FP8,
8989
QUANTIZATION_W4A16_NVFP4,
9090
)
91-
from .model_utils import get_language_model_from_vl, is_multimodal_model
91+
from .model_utils import _reorder_canonical_first, get_language_model_from_vl, is_multimodal_model
9292
from .moe_utils import _export_fused_experts
9393
from .plugins import SpeculativeDecodingExporter, has_spec_opt
9494
from .quant_utils import (
@@ -104,6 +104,7 @@
104104
maybe_transpose_expert_weight_dimensions,
105105
postprocess_state_dict,
106106
preprocess_linear_fusion,
107+
sync_tied_input_amax,
107108
to_quantized_weight,
108109
)
109110

@@ -754,163 +755,6 @@ def _export_quantized_weight(
754755
torch.cuda.empty_cache()
755756

756757

757-
def _collect_canonical_tied_patterns(
758-
model: nn.Module,
759-
) -> tuple[list[re.Pattern], list[str]]:
760-
"""Walk the model and collect canonical-side tied-weight matchers.
761-
762-
Patterns are submodule-prefixed regexes from each module's
763-
``_tied_weights_keys`` dict-style declaration (the prefix matters
764-
for nested models where the dict lives on an inner submodule).
765-
Side substrings are dot-separated tokens that appear only on the
766-
canonical side of those declarations — needed because modelopt's
767-
per-expert unpacking creates post-export keys (e.g.
768-
``…experts.Y.gate_proj.input_scale``) that HF's regexes never knew
769-
about. List-style (legacy) declarations are skipped.
770-
"""
771-
patterns: list[re.Pattern] = []
772-
alias_token_set: set[str] = set()
773-
canonical_token_set: set[str] = set()
774-
775-
def _tokens(s: str) -> set[str]:
776-
"""Identifiers in a regex string, with regex specials as separators."""
777-
return {tok for tok in re.split(r"[^A-Za-z0-9_]+", s) if tok}
778-
779-
for name, submodule in model.named_modules():
780-
tied = getattr(submodule, "_tied_weights_keys", None)
781-
if not isinstance(tied, dict) or not tied:
782-
continue
783-
prefix = f"{name}." if name else ""
784-
for alias_pat, canonical_pat in tied.items():
785-
patterns.append(re.compile(prefix + canonical_pat))
786-
alias_token_set.update(_tokens(prefix + alias_pat))
787-
canonical_token_set.update(_tokens(prefix + canonical_pat))
788-
789-
# Tokens unique to the canonical side become substring matchers.
790-
side_substrings = sorted(canonical_token_set - alias_token_set)
791-
return patterns, side_substrings
792-
793-
794-
def _reorder_canonical_first(state_dict: dict, model: nn.Module) -> dict:
795-
r"""Reorder ``state_dict`` so canonical-side tied keys iterate first.
796-
797-
Lets the downstream first-wins data_ptr dedup keep canonical names.
798-
Uses both regex patterns and substring matchers from
799-
:func:`_collect_canonical_tied_patterns`. Gated on the model class
800-
name to scope the reorder to DiffusionGemma; other tied
801-
encoder-decoder models that ship dict-style ``_tied_weights_keys``
802-
can be added to the allowlist here. Mirrors the ``model_type``
803-
dispatch used for the Whisper and Nemotron-VL branches elsewhere
804-
in this file.
805-
"""
806-
model_type = type(model).__name__.lower()
807-
if "diffusiongemma" not in model_type and "diffusion_gemma" not in model_type:
808-
return state_dict
809-
810-
canonical_patterns, side_substrings = _collect_canonical_tied_patterns(model)
811-
if not canonical_patterns and not side_substrings:
812-
return state_dict
813-
814-
def _has_side_substring(key: str) -> bool:
815-
# Require the token to appear as a proper dot-separated path
816-
# component, not just as a substring of an unrelated identifier.
817-
for tok in side_substrings:
818-
if (
819-
f".{tok}." in key
820-
or key.startswith(f"{tok}.")
821-
or key.endswith(f".{tok}")
822-
or key == tok
823-
):
824-
return True
825-
return False
826-
827-
head: dict = {}
828-
tail: dict = {}
829-
for k, v in state_dict.items():
830-
if any(p.search(k) for p in canonical_patterns) or _has_side_substring(k):
831-
head[k] = v
832-
else:
833-
tail[k] = v
834-
head.update(tail)
835-
return head
836-
837-
838-
def sync_tied_input_amax(model: nn.Module) -> int:
839-
"""Max-merge input_quantizer amaxes across modules sharing a weight ``data_ptr``.
840-
841-
Mutates ``model`` in place: overwrites the ``.amax`` buffer on every
842-
affected ``input_quantizer`` with the per-group maximum. Intended to
843-
run as part of an export pipeline that already replaces weights with
844-
packed bytes downstream — i.e. the model is not expected to be reused
845-
after this helper runs.
846-
847-
Closes the loop on ``input_scale`` for HF-tied modules whose forward
848-
paths see different activation distributions (encoder vs decoder in
849-
YOCO-style models). Must run BEFORE per-module export so the merged
850-
amax flows into ``input_scale`` derivation. Handles both dense
851-
Linears (keyed by ``weight.data_ptr()``) and fused MoE (keyed by
852-
``(gate_up_proj, down_proj)`` data_ptr tuple). Returns the number of
853-
tied groups merged.
854-
"""
855-
from collections import defaultdict
856-
857-
by_dp: dict = defaultdict(list)
858-
for _, m in model.named_modules():
859-
# Fused MoE: 3-D source tensors with shared input quantizers
860-
if (
861-
hasattr(m, "gate_up_proj_input_quantizer")
862-
and hasattr(m, "gate_up_proj")
863-
and hasattr(m, "down_proj")
864-
and m.gate_up_proj.dim() == 3
865-
):
866-
key = ("moe", m.gate_up_proj.data_ptr(), m.down_proj.data_ptr())
867-
by_dp[key].append(m)
868-
# Dense quantized Linear with an input_quantizer
869-
elif (
870-
hasattr(m, "input_quantizer")
871-
and hasattr(m, "weight")
872-
and isinstance(m.weight, torch.nn.Parameter)
873-
):
874-
by_dp[("dense", m.weight.data_ptr())].append(m)
875-
876-
def _merge(quantizers: list) -> bool:
877-
"""Max-merge amaxes across the quantizer list. Returns True on merge."""
878-
valid = [
879-
q
880-
for q in quantizers
881-
if q is not None
882-
and getattr(q, "is_enabled", False)
883-
and getattr(q, "_amax", None) is not None
884-
and not q._amax.is_meta
885-
]
886-
if len(valid) < 2:
887-
return False
888-
# Require scalar (per-tensor) amax — matches preprocess_linear_fusion.
889-
if any(q._amax.numel() != 1 for q in valid):
890-
warnings.warn(
891-
"sync_tied_input_amax: non-scalar input_quantizer amax encountered "
892-
"in a tied group; skipping. Only per-tensor input quantizers are "
893-
"supported for tied-modules merging."
894-
)
895-
return False
896-
merged = torch.max(torch.stack([q.amax for q in valid]))
897-
for q in valid:
898-
q.amax = merged.clone()
899-
return True
900-
901-
synced = 0
902-
for key, modules in by_dp.items():
903-
if len(modules) < 2:
904-
continue
905-
if key[0] == "moe":
906-
for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"):
907-
if _merge([getattr(m, q_name, None) for m in modules]):
908-
synced += 1
909-
elif _merge([m.input_quantizer for m in modules]):
910-
synced += 1
911-
return synced
912-
913-
914758
def _process_quantized_modules(
915759
model: nn.Module,
916760
dtype: torch.dtype,

tests/unit/torch/export/test_unified_export_hf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
)
2525

2626
import modelopt.torch.quantization as mtq
27-
from modelopt.torch.export.unified_export_hf import (
27+
from modelopt.torch.export.model_utils import (
2828
_collect_canonical_tied_patterns,
29-
_export_quantized_weight,
3029
_reorder_canonical_first,
31-
sync_tied_input_amax,
3230
)
31+
from modelopt.torch.export.quant_utils import sync_tied_input_amax
32+
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
3333

3434

3535
def test_collect_canonical_tied_patterns_dict_style():

0 commit comments

Comments
 (0)