Commit 30c334b
Add ONNX Sub Functions Export Feature for AutoModelForCausalLM (#621)
# ONNX Functions Export Support

## Overview
This PR introduces support for exporting model submodules as ONNX **functions**, enabling more efficient model compilation and execution on hardware.

## Key Changes
- Added a new flag **`use_onnx_subfunctions`** to control ONNX function export behavior.
- Integrated ONNX function export capability into the inference pipeline.

## How to Enable ONNX Function Export
Pass the flag when exporting or compiling:

```python
model.export(tmp_path, use_onnx_subfunctions=True)
```

## Backward Compatibility
This feature is **opt-in**: it takes effect only when the `use_onnx_subfunctions` flag is passed explicitly. Existing workflows remain unaffected when the flag is disabled.

---------

Signed-off-by: abhishek-singh591 <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Co-authored-by: quic-akuruvil <[email protected]>
1 parent b014f72 commit 30c334b

File tree

11 files changed: +488 additions, −34 deletions
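Before the per-file diffs, here is a minimal end-to-end usage sketch of the new flag. The wrapper class name, model id, and paths are assumptions for illustration (not part of this diff); it assumes the public `QEFFAutoModelForCausalLM` wrapper forwards `use_onnx_subfunctions` to the `_export`/`_compile` paths changed below, the way the PEFT `compile` in this diff does.

```python
# Hedged usage sketch: class name, model id, and paths are placeholders.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Opt in to exporting each decoder layer as an ONNX function (experimental).
onnx_path = model.export("exports/gpt2", use_onnx_subfunctions=True)

# The compile path accepts the same flag and forwards it if export
# still needs to run.
qpc_path = model.compile(num_cores=16, use_onnx_subfunctions=True)
```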

QEfficient/base/modeling_qeff.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -8,6 +8,7 @@
 import gc
 import inspect
 import logging
+import re
 import shutil
 import subprocess
 import warnings
@@ -18,10 +19,12 @@
 import onnx
 import torch
 
-from QEfficient.base.onnx_transforms import OnnxTransform
+from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.cache_utils import InvalidIndexProvider
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
@@ -32,6 +35,7 @@
     hash_dict_params,
     load_json,
 )
+from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
 
 logger = logging.getLogger(__name__)
 
@@ -179,6 +183,7 @@ def _export(
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
+        use_onnx_subfunctions: bool = False,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -243,7 +248,19 @@
                 input_names.append(param)
 
         try:
+            # Initialize the registry with your custom ops
             export_kwargs = {} if export_kwargs is None else export_kwargs
+            if use_onnx_subfunctions:
+                warnings.warn(
+                    "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
+                )
+                apply_torch_patches()
+                InvalidIndexProvider.SUBFUNC_ENABLED = True
+                output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
+                export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
+                self._onnx_transforms.append(RenameFunctionOutputsTransform)
+                self._onnx_transforms.append(CustomOpTransform)
+
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
@@ -255,7 +272,6 @@
                 **export_kwargs,
             )
             logger.info("PyTorch export successful")
-
             _ = self._offload_model_weights(offload_pt_weights)
 
             model = onnx.load(tmp_onnx_path, load_external_data=False)
@@ -284,6 +300,12 @@
         finally:
             shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
 
+        if use_onnx_subfunctions:
+            undo_torch_patches()
+            InvalidIndexProvider.SUBFUNC_ENABLED = False
+            self._onnx_transforms.remove(CustomOpTransform)
+            self._onnx_transforms.remove(RenameFunctionOutputsTransform)
+
         self.onnx_path = onnx_path
         return onnx_path
 
@@ -300,6 +322,7 @@ def _compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: Optional[bool] = False,
         qnn_config: Optional[str] = None,
+        use_onnx_subfunctions: bool = False,
         **compiler_options,
     ) -> str:
         """
@@ -325,8 +348,9 @@
 
         For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
         """
+
        if onnx_path is None and self.onnx_path is None:
-            self.export()
+            self.export(use_onnx_subfunctions=use_onnx_subfunctions)
 
        onnx_path = Path(onnx_path or self.onnx_path)
        compile_dir = Path(compile_dir or onnx_path.parent)
```
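The heart of this change is the `export_kwargs["export_modules_as_functions"]` line: `torch.onnx.export` accepts a set of `nn.Module` classes and emits every instance of those classes as a call to a local ONNX function instead of inlining its ops. A self-contained sketch of that PyTorch mechanism (toy modules, not QEfficient code):

```python
# Standalone sketch of torch.onnx's export_modules_as_functions argument.
# The toy Block/TinyModel classes are illustrative only.
import torch


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.proj(x))


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block() for _ in range(2))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


torch.onnx.export(
    TinyModel(),
    (torch.randn(1, 8),),
    "tiny.onnx",
    opset_version=17,  # export_modules_as_functions requires opset >= 15
    export_modules_as_functions={Block},  # each Block call becomes an ONNX function
)
```

In the diff, `get_decoder_layer_classes_for_export(self.model)` plays the role of `{Block}`, so each decoder layer is emitted once as a reusable function body rather than being duplicated per layer.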

QEfficient/base/onnx_transforms.py

Lines changed: 96 additions & 1 deletion
```diff
@@ -5,11 +5,15 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
 
+from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
+
 
 class OnnxTransform:
     """
@@ -99,3 +103,94 @@ def apply(
                 current_file_size = tsize
                 external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
         return model, transformed
+
+
+class CustomOpTransform(OnnxTransform):
+    """
+    Transform to register custom operations and add their function protos to the ONNX model.
+    """
+
+    _custom_ops: Dict[str, Tuple[Any, Any]] = {
+        "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm),
+        "CtxScatterFunc": (CtxScatterFunc, CtxScatter),
+        "CtxGatherFunc": (CtxGatherFunc, CtxGather),
+    }
+
+    @classmethod
+    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None:
+        """Register a custom operation."""
+        cls._custom_ops[op_name] = (func_class, onnxscript_func)
+
+    @classmethod
+    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Apply custom op registration and add all function protos to the model.
+
+        :param model: The ONNX model to transform.
+        :param opset_version: ONNX opset version for symbolic registration.
+        :returns: (Transformed model, success flag).
+        """
+        transformed = False
+
+        # Register all custom op symbolic functions with torch.onnx
+        for op_name, (func_class, _) in cls._custom_ops.items():
+            if hasattr(func_class, "symbolic"):
+                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
+
+        func_names = {func.name for func in model.functions}
+
+        for _, onnxscript_func in cls._custom_ops.values():
+            proto = onnxscript_func.to_function_proto()
+            if proto.name not in func_names:
+                model.functions.append(proto)
+                transformed = True
+
+        return model, transformed
+
+
+class RenameFunctionOutputsTransform(OnnxTransform):
+    """
+    Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
+    """
+
+    @classmethod
+    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Rename function outputs in decoder layer nodes.
+
+        :param model: The ONNX model to transform
+        :returns: Transformed model and boolean indicating whether transform was applied
+        """
+        graph = model.graph
+        op_type_to_func_map = {func.name: func for func in model.functions}
+        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
+        transformed = False
+
+        # Create a dict mapping output name to its index for quick lookup
+        model_graph_outputs_map = {val.name: idx for idx, val in enumerate(model.graph.output)}
+
+        layer_index = 0
+        for node in graph.node:
+            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
+                func = op_type_to_func_map.get(node.op_type)
+                if func is None:
+                    continue
+
+                for i, out_name in enumerate(func.output):
+                    if "_InternalRetainedState" in out_name:
+                        transformed = True
+                        original_output_name = node.output[i]
+
+                        # Generate new name based on key/value
+                        if "key" in out_name:
+                            new_name = f"past_key.{layer_index}_RetainedState"
+                        elif "value" in out_name:
+                            new_name = f"past_value.{layer_index}_RetainedState"
+                        node.output[i] = new_name
+
+                        # Update graph output name if it exists
+                        if original_output_name in model_graph_outputs_map:
+                            idx = model_graph_outputs_map[original_output_name]
+                            model.graph.output[idx].name = new_name
+                layer_index += 1
+        return model, transformed
```
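Both new transforms follow the existing `OnnxTransform.apply(model, **kwargs) -> (model, transformed)` contract, so they can also be run standalone on an already-exported model. A hedged sketch (file paths are placeholders), applying them in the same order `_export` appends them:

```python
# Hedged standalone sketch; "tiny.onnx" is a placeholder path.
import onnx

from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform

model = onnx.load("tiny.onnx", load_external_data=False)

# Restore the public past_{key,value}.<i>_RetainedState names on the
# decoder-layer function outputs ...
model, renamed = RenameFunctionOutputsTransform.apply(model)

# ... then register the custom-op symbolics and append their function protos.
model, ops_added = CustomOpTransform.apply(model, opset_version=17)

onnx.save(model, "tiny_transformed.onnx")
```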

QEfficient/peft/auto.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
         return obj
 
-    def export(self, export_dir: Optional[str] = None) -> str:
+    def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
         """
         Export the model with the active adapter to ONNX format.
@@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             export_kwargs={"do_constant_folding": False},  # To avoid merging adapter weights with base weights
             onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
             export_dir=export_dir,
+            use_onnx_subfunctions=use_onnx_subfunctions,
         )
 
     def compile(
@@ -300,6 +301,7 @@ def compile(
         num_cores: int = 16,
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
+        use_onnx_subfunctions: bool = False,
         **compiler_options,
     ) -> str:
         """
@@ -367,6 +369,7 @@ def compile(
             mdp_ts_num_devices=num_devices,
             aic_num_cores=num_cores,
             mxint8_kv_cache=mxint8_kv_cache,
+            use_onnx_subfunctions=use_onnx_subfunctions,
             **compiler_options,
         )
```

QEfficient/peft/lora/auto.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
         # load_weight to model
         self._load_adapter_weights_to_model()
 
-    def export(self, export_dir: Optional[str] = None) -> str:
+    def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
         """
         Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.
@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             output_names,
             dynamic_axes,
             export_dir=export_dir,
+            use_onnx_subfunctions=use_onnx_subfunctions,
         )
 
     def generate(
```

QEfficient/transformers/cache_utils.py

Lines changed: 30 additions & 12 deletions
```diff
@@ -24,6 +24,33 @@
 )
 
 
+class InvalidIndexProvider:
+    SUBFUNC_ENABLED = False
+
+    @classmethod
+    def enable_subfunc(cls):
+        cls.SUBFUNC_ENABLED = True
+
+    @classmethod
+    def _get_invalid_idx_value(cls):
+        """
+        Get the appropriate invalid index value for CtxGather operations.
+
+        For ONNX export with functions, we use 0 to avoid INT32_MAX constants
+        that cause issues when functions are inlined at runtime.
+
+        Returns:
+            int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise)
+        """
+        if torch.onnx.is_in_onnx_export():
+            if cls.SUBFUNC_ENABLED:
+                return 0
+            else:
+                return torch.iinfo(torch.int32).max
+        else:
+            return 0
+
+
 class QEffDynamicLayer(DynamicLayer):
     def read_only(self, cache_kwargs):
         """
@@ -46,10 +73,7 @@ def read_only(self, cache_kwargs):
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
 
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
@@ -143,10 +167,7 @@ def update(
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
 
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
         if batch_index is not None:
@@ -419,10 +440,7 @@ def update(
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
         all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
```
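This refactor replaces three copies of the same branch with one provider whose result now also depends on whether subfunction export is active. The decision logic, restated from the diff above as a standalone sketch:

```python
# Standalone restatement of InvalidIndexProvider._get_invalid_idx_value
# (values taken from the diff above; the helper name is ours).
import torch


def invalid_idx_value(in_onnx_export: bool, subfunc_enabled: bool) -> int:
    if in_onnx_export and not subfunc_enabled:
        # Monolithic export keeps the INT32_MAX sentinel for CtxGather.
        return torch.iinfo(torch.int32).max
    # Eager execution, and subfunction export: use 0 so no INT32_MAX
    # constant is baked into a function body that gets inlined at runtime.
    return 0


assert invalid_idx_value(False, False) == 0
assert invalid_idx_value(True, False) == 2**31 - 1
assert invalid_idx_value(True, True) == 0
```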
