|
88 | 88 | QUANTIZATION_W4A8_NVFP4_FP8, |
89 | 89 | QUANTIZATION_W4A16_NVFP4, |
90 | 90 | ) |
91 | | -from .model_utils import get_language_model_from_vl, is_multimodal_model |
| 91 | +from .model_utils import _reorder_canonical_first, get_language_model_from_vl, is_multimodal_model |
92 | 92 | from .moe_utils import _export_fused_experts |
93 | 93 | from .plugins import SpeculativeDecodingExporter, has_spec_opt |
94 | 94 | from .quant_utils import ( |
|
104 | 104 | maybe_transpose_expert_weight_dimensions, |
105 | 105 | postprocess_state_dict, |
106 | 106 | preprocess_linear_fusion, |
| 107 | + sync_tied_input_amax, |
107 | 108 | to_quantized_weight, |
108 | 109 | ) |
109 | 110 |
|
@@ -754,163 +755,6 @@ def _export_quantized_weight( |
754 | 755 | torch.cuda.empty_cache() |
755 | 756 |
|
756 | 757 |
|
757 | | -def _collect_canonical_tied_patterns( |
758 | | - model: nn.Module, |
759 | | -) -> tuple[list[re.Pattern], list[str]]: |
760 | | - """Walk the model and collect canonical-side tied-weight matchers. |
761 | | -
|
762 | | - Patterns are submodule-prefixed regexes from each module's |
763 | | - ``_tied_weights_keys`` dict-style declaration (the prefix matters |
764 | | - for nested models where the dict lives on an inner submodule). |
765 | | - Side substrings are dot-separated tokens that appear only on the |
766 | | - canonical side of those declarations — needed because modelopt's |
767 | | - per-expert unpacking creates post-export keys (e.g. |
768 | | - ``…experts.Y.gate_proj.input_scale``) that HF's regexes never knew |
769 | | - about. List-style (legacy) declarations are skipped. |
770 | | - """ |
771 | | - patterns: list[re.Pattern] = [] |
772 | | - alias_token_set: set[str] = set() |
773 | | - canonical_token_set: set[str] = set() |
774 | | - |
775 | | - def _tokens(s: str) -> set[str]: |
776 | | - """Identifiers in a regex string, with regex specials as separators.""" |
777 | | - return {tok for tok in re.split(r"[^A-Za-z0-9_]+", s) if tok} |
778 | | - |
779 | | - for name, submodule in model.named_modules(): |
780 | | - tied = getattr(submodule, "_tied_weights_keys", None) |
781 | | - if not isinstance(tied, dict) or not tied: |
782 | | - continue |
783 | | - prefix = f"{name}." if name else "" |
784 | | - for alias_pat, canonical_pat in tied.items(): |
785 | | - patterns.append(re.compile(prefix + canonical_pat)) |
786 | | - alias_token_set.update(_tokens(prefix + alias_pat)) |
787 | | - canonical_token_set.update(_tokens(prefix + canonical_pat)) |
788 | | - |
789 | | - # Tokens unique to the canonical side become substring matchers. |
790 | | - side_substrings = sorted(canonical_token_set - alias_token_set) |
791 | | - return patterns, side_substrings |
792 | | - |
793 | | - |
794 | | -def _reorder_canonical_first(state_dict: dict, model: nn.Module) -> dict: |
795 | | - r"""Reorder ``state_dict`` so canonical-side tied keys iterate first. |
796 | | -
|
797 | | - Lets the downstream first-wins data_ptr dedup keep canonical names. |
798 | | - Uses both regex patterns and substring matchers from |
799 | | - :func:`_collect_canonical_tied_patterns`. Gated on the model class |
800 | | - name to scope the reorder to DiffusionGemma; other tied |
801 | | - encoder-decoder models that ship dict-style ``_tied_weights_keys`` |
802 | | - can be added to the allowlist here. Mirrors the ``model_type`` |
803 | | - dispatch used for the Whisper and Nemotron-VL branches elsewhere |
804 | | - in this file. |
805 | | - """ |
806 | | - model_type = type(model).__name__.lower() |
807 | | - if "diffusiongemma" not in model_type and "diffusion_gemma" not in model_type: |
808 | | - return state_dict |
809 | | - |
810 | | - canonical_patterns, side_substrings = _collect_canonical_tied_patterns(model) |
811 | | - if not canonical_patterns and not side_substrings: |
812 | | - return state_dict |
813 | | - |
814 | | - def _has_side_substring(key: str) -> bool: |
815 | | - # Require the token to appear as a proper dot-separated path |
816 | | - # component, not just as a substring of an unrelated identifier. |
817 | | - for tok in side_substrings: |
818 | | - if ( |
819 | | - f".{tok}." in key |
820 | | - or key.startswith(f"{tok}.") |
821 | | - or key.endswith(f".{tok}") |
822 | | - or key == tok |
823 | | - ): |
824 | | - return True |
825 | | - return False |
826 | | - |
827 | | - head: dict = {} |
828 | | - tail: dict = {} |
829 | | - for k, v in state_dict.items(): |
830 | | - if any(p.search(k) for p in canonical_patterns) or _has_side_substring(k): |
831 | | - head[k] = v |
832 | | - else: |
833 | | - tail[k] = v |
834 | | - head.update(tail) |
835 | | - return head |
836 | | - |
837 | | - |
def sync_tied_input_amax(model: nn.Module) -> int:
    """Max-merge input_quantizer amaxes across modules sharing a weight ``data_ptr``.

    Mutates ``model`` in place: overwrites the amax of every affected
    input quantizer with the per-group maximum. Intended to run as part
    of an export pipeline, BEFORE per-module export, so the merged amax
    is what any later ``input_scale`` derivation sees.

    Grouping:

    * Fused MoE modules (those with a ``gate_up_proj_input_quantizer``
      and 3-D ``gate_up_proj`` / ``down_proj`` tensors) are keyed by the
      ``(gate_up_proj, down_proj)`` data_ptr pair; their
      ``gate_up_proj_input_quantizer`` and ``down_proj_input_quantizer``
      are merged independently.
    * Other modules with an ``input_quantizer`` and a ``weight``
      ``nn.Parameter`` are keyed by ``weight.data_ptr()``.

    Tensors on the meta device or with zero elements are ignored for
    grouping, because their ``data_ptr()`` does not identify shared
    storage. Within a group only enabled quantizers holding a
    materialised ``_amax`` take part, and only per-tensor (single
    element) amaxes are supported; a group containing a non-scalar amax
    is skipped with a warning.

    Args:
        model: Root module to scan and mutate.

    Returns:
        Number of quantizer groups that were merged.
    """
    from collections import defaultdict

    def _has_identity(t) -> bool:
        # data_ptr() is only a meaningful identity for materialised,
        # non-empty tensors.
        return isinstance(t, torch.Tensor) and not t.is_meta and t.numel() > 0

    by_dp: dict = defaultdict(list)
    for _, m in model.named_modules():
        gate_up = getattr(m, "gate_up_proj", None)
        down = getattr(m, "down_proj", None)
        # Fused MoE: 3-D source tensors with shared input quantizers.
        if (
            hasattr(m, "gate_up_proj_input_quantizer")
            and isinstance(gate_up, torch.Tensor)
            and isinstance(down, torch.Tensor)
            and gate_up.dim() == 3
        ):
            if _has_identity(gate_up) and _has_identity(down):
                by_dp[("moe", gate_up.data_ptr(), down.data_ptr())].append(m)
            continue
        # Dense quantized Linear with an input_quantizer.
        weight = getattr(m, "weight", None)
        if (
            hasattr(m, "input_quantizer")
            and isinstance(weight, torch.nn.Parameter)
            and _has_identity(weight)
        ):
            by_dp[("dense", weight.data_ptr())].append(m)

    def _merge(quantizers: list) -> bool:
        """Max-merge amaxes across the quantizer list. Returns True on merge."""
        valid = [
            q
            for q in quantizers
            if q is not None
            and getattr(q, "is_enabled", False)
            and getattr(q, "_amax", None) is not None
            and not q._amax.is_meta
        ]
        if len(valid) < 2:
            return False
        # Require scalar (per-tensor) amax.
        if any(q._amax.numel() != 1 for q in valid):
            warnings.warn(
                "sync_tied_input_amax: non-scalar input_quantizer amax encountered "
                "in a tied group; skipping. Only per-tensor input quantizers are "
                "supported for tied-modules merging."
            )
            return False
        # Normalise to 0-dim float32 on one device before stacking so that
        # mixed ``()`` / ``(1,)`` shapes, dtypes or devices cannot break it.
        ref_device = valid[0]._amax.device
        merged = torch.stack(
            [
                q._amax.detach().reshape(()).to(device=ref_device, dtype=torch.float32)
                for q in valid
            ]
        ).max()
        for q in valid:
            current = q._amax
            # Write back in the quantizer's own shape/dtype/device; the
            # amax setter may reject a shape change (not verifiable here).
            q.amax = (
                merged.to(device=current.device, dtype=current.dtype)
                .reshape(current.shape)
                .clone()
            )
        return True

    synced = 0
    for key, modules in by_dp.items():
        if len(modules) < 2:
            continue
        if key[0] == "moe":
            for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"):
                if _merge([getattr(m, q_name, None) for m in modules]):
                    synced += 1
        elif _merge([m.input_quantizer for m in modules]):
            synced += 1
    return synced
912 | | - |
913 | | - |
914 | 758 | def _process_quantized_modules( |
915 | 759 | model: nn.Module, |
916 | 760 | dtype: torch.dtype, |
|
0 commit comments