diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..72f5c050e 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,6 +8,7 @@ import gc import inspect import logging +import re import shutil import subprocess import warnings @@ -18,10 +19,12 @@ import onnx import torch -from QEfficient.base.onnx_transforms import OnnxTransform +from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.cache_utils import InvalidIndexProvider +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import ( constants, create_json, @@ -32,6 +35,7 @@ hash_dict_params, load_json, ) +from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches logger = logging.getLogger(__name__) @@ -179,6 +183,7 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + use_onnx_subfunctions: bool = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -243,7 +248,19 @@ def _export( input_names.append(param) try: + # Initialize the registry with your custom ops export_kwargs = {} if export_kwargs is None else export_kwargs + if use_onnx_subfunctions: + warnings.warn( + "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." + ) + apply_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = True + output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] + export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) + self._onnx_transforms.append(RenameFunctionOutputsTransform) + self._onnx_transforms.append(CustomOpTransform) + torch.onnx.export( self.model, (example_inputs,), @@ -255,7 +272,6 @@ def _export( **export_kwargs, ) logger.info("PyTorch export successful") - _ = self._offload_model_weights(offload_pt_weights) model = onnx.load(tmp_onnx_path, load_external_data=False) @@ -284,6 +300,12 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) + if use_onnx_subfunctions: + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + self._onnx_transforms.remove(CustomOpTransform) + self._onnx_transforms.remove(RenameFunctionOutputsTransform) + self.onnx_path = onnx_path return onnx_path @@ -300,6 +322,7 @@ def _compile( num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -325,8 +348,9 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" + if onnx_path is None and self.onnx_path is None: - self.export() + self.export(use_onnx_subfunctions=use_onnx_subfunctions) onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 61b5c00f6..7ebe6bce5 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,11 +5,15 @@ # # ---------------------------------------------------------------------------- -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import numpy as np +import torch from onnx import ModelProto, external_data_helper, numpy_helper +from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc + class OnnxTransform: """ @@ -99,3 +103,94 @@ def apply( current_file_size = tsize external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") return model, transformed + + +class CustomOpTransform(OnnxTransform): + """ + Transform to register custom operations and add their function protos to the ONNX model. + """ + + _custom_ops: Dict[str, Tuple[Any, Any]] = { + "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), + "CtxScatterFunc": (CtxScatterFunc, CtxScatter), + "CtxGatherFunc": (CtxGatherFunc, CtxGather), + } + + @classmethod + def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None: + """Register a custom operation.""" + cls._custom_ops[op_name] = (func_class, onnxscript_func) + + @classmethod + def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]: + """ + Apply custom op registration and add all function protos to the model. + + :param model: The ONNX model to transform. + :param opset_version: ONNX opset version for symbolic registration. + :returns: (Transformed model, success flag). + """ + transformed = False + + # Register all custom op symbolic functions with torch.onnx + for op_name, (func_class, _) in cls._custom_ops.items(): + if hasattr(func_class, "symbolic"): + torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version) + + func_names = {func.name for func in model.functions} + + for _, onnxscript_func in cls._custom_ops.values(): + proto = onnxscript_func.to_function_proto() + if proto.name not in func_names: + model.functions.append(proto) + transformed = True + + return model, transformed + + +class RenameFunctionOutputsTransform(OnnxTransform): + """ + Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns. + """ + + @classmethod + def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: + """ + Rename function outputs in decoder layer nodes. 
+ + :param model: The ONNX model to transform + :returns: Transformed model and boolean indicating whether transform was applied + """ + graph = model.graph + op_type_to_func_map = {func.name: func for func in model.functions} + decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"] + transformed = False + + # Create a dict mapping output name to its index for quick lookup + model_graph_outputs_map = {val.name: idx for idx, val in enumerate(model.graph.output)} + + layer_index = 0 + for node in graph.node: + if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns): + func = op_type_to_func_map.get(node.op_type) + if func is None: + continue + + for i, out_name in enumerate(func.output): + if "_InternalRetainedState" in out_name: + transformed = True + original_output_name = node.output[i] + + # Generate new name based on key/value + if "key" in out_name: + new_name = f"past_key.{layer_index}_RetainedState" + elif "value" in out_name: + new_name = f"past_value.{layer_index}_RetainedState" + node.output[i] = new_name + + # Update graph output name if it exists + if original_output_name in model_graph_outputs_map: + idx = model_graph_outputs_map[original_output_name] + model.graph.output[idx].name = new_name + layer_index += 1 + return model, transformed diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 592c0c1d3..99d64cc2f 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model with the active adapter to ONNX format. @@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str: export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -300,6 +301,7 @@ def compile( num_cores: int = 16, mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -367,6 +369,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 8196cd769..64fa3f61c 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
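For context on the mechanism that `use_onnx_subfunctions` enables in `_export` above: the decoder-layer classes returned by `get_decoder_layer_classes_for_export` are handed to `torch.onnx.export` via its `export_modules_as_functions` argument, so each layer is emitted as a local ONNX function instead of being inlined into the main graph; `CustomOpTransform` and `RenameFunctionOutputsTransform` then run as post-export ONNX passes. The following is a minimal, self-contained sketch of that torch feature using toy module names (not QEfficient classes); it assumes a torch build whose TorchScript-based exporter supports `export_modules_as_functions` (opset >= 15).

```python
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Stand-in for a QEff decoder layer class (illustrative only)."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(ToyDecoderLayer(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel().eval()
torch.onnx.export(
    model,
    (torch.randn(1, 4, 16),),
    "toy_subfunctions.onnx",            # placeholder output path
    input_names=["x"],
    output_names=["y"],
    opset_version=17,                   # export_modules_as_functions needs opset >= 15
    export_modules_as_functions={ToyDecoderLayer},  # emit one ONNX function per layer class
)
```

In `_export` this is paired with renaming the `_RetainedState` output names to `_InternalRetainedState`, so that the KV-cache outputs produced inside the functions can later be mapped back to the external names by `RenameFunctionOutputsTransform`.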
@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def generate( diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 5452589f6..292fe0487 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -24,6 +24,33 @@ ) +class InvalidIndexProvider: + SUBFUNC_ENABLED = False + + @classmethod + def enable_subfunc(cls): + cls.SUBFUNC_ENABLED = True + + @classmethod + def _get_invalid_idx_value(cls): + """ + Get the appropriate invalid index value for CtxGather operations. + + For ONNX export with functions, we use 0 to avoid INT32_MAX constants + that cause issues when functions are inlined at runtime. + + Returns: + int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise) + """ + if torch.onnx.is_in_onnx_export(): + if cls.SUBFUNC_ENABLED: + return 0 + else: + return torch.iinfo(torch.int32).max + else: + return 0 + + class QEffDynamicLayer(DynamicLayer): def read_only(self, cache_kwargs): """ @@ -46,10 +73,7 @@ def read_only(self, cache_kwargs): gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) @@ -143,10 +167,7 @@ def update( gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: @@ -419,10 +440,7 @@ def update( ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5f1ec51e6..cbff5be91 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -27,7 +27,10 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import ( + FP16ClipTransform, + SplitTensorsTransform, +) from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( @@ -315,7 +318,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -327,6 +330,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -350,6 +355,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -362,6 +368,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -388,6 +395,8 @@ def compile( Number of cores to use for compilation. mxfp6_matmul : bool, optional Use MXFP6 compression for weights. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. These are passed directly to the underlying compilation command. @@ -431,6 +440,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -595,7 +605,15 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + use_onnx_subfunctions: bool = False, + ): """ Exports the vision encoder component to ONNX format. 
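Returning to the `cache_utils` change above: `InvalidIndexProvider._get_invalid_idx_value()` centralizes the sentinel used to mask out-of-range `CtxGather` indices. During a regular ONNX export the sentinel stays `INT32_MAX`, but when subfunctions are enabled it becomes `0`, because `INT32_MAX` constants cause issues once the functions are inlined at runtime (per the new docstring); outside export it is always `0`. A small runnable sketch of the masking pattern, with illustrative values:

```python
import torch

ctx_len = 8
position_ids = torch.tensor([[3]])                     # current write position (illustrative)
ctx_indices = torch.arange(ctx_len)[None, None, ...]   # shape (1, 1, ctx_len)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

# Outside ONNX export (and with subfunctions enabled during export), the sentinel is 0.
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
print(ctx_indices)  # tensor([[[0, 1, 2, 3, 0, 0, 0, 0]]]) -- positions past 3 are clamped
```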
@@ -611,6 +629,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -618,7 +638,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the vision encoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -631,6 +656,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -654,6 +680,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -671,6 +699,7 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -737,7 +766,15 @@ def __init__(self, model, **kwargs): self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + use_onnx_subfunctions: bool = False, + ): """ Exports the language decoder component to ONNX format. @@ -753,6 +790,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -760,7 +799,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -773,6 +817,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -796,6 +841,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. 
+ use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -813,6 +860,7 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -973,6 +1021,7 @@ def qpc_path(self): def export( self, export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -985,6 +1034,8 @@ def export( ---------- export_dir : str, optional Directory path where the exported ONNX graphs will be saved. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **kwargs : Additional keyword arguments. @@ -1018,9 +1069,15 @@ def export( dynamic_axes["vision"], export_dir=export_dir, offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, ) self.lang_model.export( - inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + offload_pt_weights=True, + use_onnx_subfunctions=use_onnx_subfunctions, ) return self.onnx_path @@ -1043,6 +1100,7 @@ def compile( mxint8_kv_cache: bool = False, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1082,6 +1140,8 @@ def compile( If True, skips compilation of the vision encoder. Default is False. skip_lang : bool, optional If True, skips compilation of the language decoder. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1154,7 +1214,9 @@ def compile( if (self.vision_model.onnx_path is None and vision_onnx_path is None) or ( self.lang_model.onnx_path is None and lang_onnx_path is None ): - self.export() + self.export( + use_onnx_subfunctions=use_onnx_subfunctions, + ) # TODO this hould be removed once the continous batching is supported for all the models. 
compiler_options.pop("continuous_batching", None) @@ -1172,6 +1234,7 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_vision, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -1200,6 +1263,7 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_lang, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path @@ -1624,6 +1688,7 @@ def from_pretrained( def export( self, export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1644,7 +1709,13 @@ def export( inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + ) def compile( self, @@ -1662,6 +1733,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1697,6 +1769,8 @@ def compile( Use MXINT8 compression for KV cache. Default is False. num_speculative_tokens : int, optional Not supported for this model; must be None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1769,6 +1843,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path @@ -2232,7 +2307,10 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + _onnx_transforms = [ + FP16ClipTransform, + SplitTensorsTransform, + ] def __init__( self, @@ -2423,7 +2501,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2436,7 +2514,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. - + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. 
Defaults to False Returns ------- str @@ -2532,6 +2611,8 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=kwargs.get("offload_pt_weights", True), ) def get_sampling_inputs_and_outputs( @@ -2742,6 +2823,7 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -2783,6 +2865,8 @@ def compile( prefill_only : bool, optional If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -2944,6 +3028,7 @@ def compile( num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -3135,7 +3220,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -3147,6 +3232,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -3156,7 +3243,13 @@ def export(self, export_dir: Optional[str] = None) -> str: inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + ) def compile( self, @@ -3174,6 +3267,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3215,6 +3309,8 @@ def compile( Not yet supported for this model. num_speculative_tokens : int, optional Not yet supported for this model. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC. 
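Because every `export`/`compile` override in `modeling_auto.py` now threads the flag down to `QEFFBaseModel._export`/`_compile`, enabling subfunctions from user code is a single keyword argument. A hedged usage sketch (model name and sequence lengths are placeholders; adjust for your target):

```python
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")   # placeholder model
onnx_path = model.export(use_onnx_subfunctions=True)       # decoder layers exported as ONNX functions
qpc_path = model.compile(
    prefill_seq_len=8,
    ctx_len=16,
    num_cores=16,
    use_onnx_subfunctions=True,  # only triggers a re-export if no ONNX artifact exists yet
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(model.generate(prompts=["Help me with this"], tokenizer=tokenizer))
```

Per the warning added in `_export`, compiling the same model alternately with and without subfunctions may produce inconsistent results, so keep the flag consistent across `export` and `compile` for a given instance.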
@@ -3282,6 +3378,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -3499,12 +3596,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. ``Optional`` Args: :export_dir (str, optional): The directory path to store ONNX-graph. + :use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns: :str: Path of the generated ``ONNX`` graph. @@ -3525,6 +3624,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -3537,6 +3637,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3552,6 +3653,7 @@ def compile( :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. + :use_onnx_subfunctions: bool, optional: whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False :compiler_options (dict, optional): Additional compiler options. For QAIC Compiler: Extra arguments for qaic-exec can be passed. @@ -3584,6 +3686,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..62a873b9e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -821,3 +821,29 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu model = PooledModel(model, pooling_method) warnings.warn("Pooling is applied to the model.") return model, transformed + + +def get_decoder_layer_classes_for_export(model: nn.Module) -> set: + """ + Dynamically determine which DecoderLayer classes should be exported as functions + based on the model's architecture using the existing KVCacheTransform mapping. 
+ """ + # Define patterns that identify decoder layer classes + DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] + + # Get all QEff classes that are decoder layers from the existing mapping + decoder_layer_classes = set() + + for original_class, qeff_class in KVCacheTransform._module_mapping.items(): + # Check if the QEff class name contains decoder layer patterns + qeff_class_name = qeff_class.__name__ + if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): + decoder_layer_classes.add(qeff_class) + + # Filter to only include classes that are actually used in the current model + model_decoder_classes = set() + for module in model.modules(): + if module.__class__ in decoder_layer_classes: + model_decoder_classes.add(module.__class__) + + return model_decoder_classes diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index d58f54952..1fb0311eb 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -566,6 +566,7 @@ def wrapper(self, *args, **kwargs): dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), + use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False), ) export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index b6b38b8b4..948b72e6a 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -55,7 +55,8 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - + if kwargs.get("use_onnx_subfunctions"): + export_params["use_onnx_subfunctions"] = True export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") diff --git a/QEfficient/utils/torch_patches.py b/QEfficient/utils/torch_patches.py new file mode 100644 index 000000000..0b9b37afa --- /dev/null +++ b/QEfficient/utils/torch_patches.py @@ -0,0 +1,115 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Monkey patches for torch.onnx.utils to fix ONNX export issues.""" + +import torch +import torch.onnx.utils as onnx_utils +from torch import _C + +# Store original references before patching +_original_setup_trace_module_map = onnx_utils._setup_trace_module_map +_original_get_module_attributes = getattr(onnx_utils, "_get_module_attributes", None) + + +def _setup_trace_module_map_patched( + model, + export_modules_as_functions, +): + """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch.""" + + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + # FIX: use empty dict to avoid type mismatch + onnx_attrs = {} + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n)) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." 
+ ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _get_module_attributes(module): + """Helper function to get module attributes safely.""" + import typing + + import torch.nn + + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def apply_torch_patches(): + """Apply monkey patches for ONNX export.""" + onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched + if hasattr(onnx_utils, "_get_module_attributes"): + onnx_utils._get_module_attributes = _get_module_attributes + + +def undo_torch_patches(): + """Undo monkey patches and restore original functions.""" + onnx_utils._setup_trace_module_map = _original_setup_trace_module_map + if _original_get_module_attributes: + onnx_utils._get_module_attributes = _original_get_module_attributes diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py new file mode 100644 index 000000000..36cfc0ce5 --- /dev/null +++ b/tests/transformers/test_subfunction.py @@ -0,0 +1,67 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +torch.manual_seed(42) + +configs = [ + ("gpt2", 256, 2, 4, 128, 512, 127, {}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] + +model_kwargs = {"attn_implementation": "eager"} +config_ids = [x.model_type for x in configs] + + +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_subfunction_vs_nonsubfunction(config, tmp_path): + tokenizer = AutoTokenizer.from_pretrained(config.model_type) + model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + # model_0_0 = QEFFAutoModelForCausalLM.from_pretrained(config.model_type) + + with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) + hash_0_0 = model_0_0.export_hash + + without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False) + hash_0_1 = model_0_0.export_hash + + assert hash_0_0 != hash_0_1 + + compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params) + generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) + + model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params) + generation_01 = 
model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) + assert generation_00.generated_texts == generation_01.generated_texts
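Beyond the functional-equivalence check in the test above, a quick structural inspection of the exported artifact can confirm that decoder layers became local ONNX functions and that `RenameFunctionOutputsTransform` restored the external retained-state names. Illustrative only; the ONNX path is a placeholder:

```python
import onnx

m = onnx.load("model.onnx", load_external_data=False)  # path to a subfunction export

print("local functions:", [f.name for f in m.functions])
retained = [o.name for o in m.graph.output if o.name.endswith("_RetainedState")]
print("retained-state outputs:", retained)  # e.g. past_key.0_RetainedState, past_value.0_RetainedState, ...

# The transform should have stripped every "_InternalRetainedState" marker.
assert not any("_InternalRetainedState" in o.name for o in m.graph.output)
```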