Commit 92f320c

pushed all changes for incorporating subfunctions in CausalLM
1 parent f4ff803 commit 92f320c

File tree

9 files changed: +515 -32 lines changed


QEfficient/base/modeling_qeff.py

Lines changed: 33 additions & 5 deletions
@@ -8,6 +8,7 @@
 import gc
 import inspect
 import logging
+import re
 import shutil
 import subprocess
 import warnings
@@ -18,10 +19,14 @@
 import onnx
 import torch
 
-from QEfficient.base.onnx_transforms import OnnxTransform
+from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
+from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.cache_utils import InvalidIndexProvider
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
@@ -32,6 +37,7 @@
     hash_dict_params,
     load_json,
 )
+from QEfficient.utils.patches import apply_torch_patches, undo_torch_patches
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +59,7 @@ class QEFFBaseModel(ABC):
     def _transform_names(cls) -> List[str]:
         return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
 
-    def __init__(self, model: torch.nn.Module, **kwargs) -> None:
+    def __init__(self, model: torch.nn.Module, use_subfunctions: bool = False, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.hash_params = create_model_params(self, **kwargs)
@@ -64,6 +70,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
             (arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
         ) or None
 
+        self.use_subfunctions = use_subfunctions
         # Flag for checking if weights are offloaded
         self._is_weights_offloaded: bool = False
 
@@ -179,6 +186,7 @@ def _export(
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
+        use_subfunctions: bool = False,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -243,7 +251,21 @@
                 input_names.append(param)
 
         try:
+            # Initialize the registry with the custom ops
             export_kwargs = {} if export_kwargs is None else export_kwargs
+            CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
+            CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
+            CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
+            if use_subfunctions:
+                warnings.warn(
+                    "The subfunction feature is experimental. Compiling consecutively with and without subfunctions may produce inconsistent results."
+                )
+                apply_torch_patches()
+                InvalidIndexProvider.SUBFUNC_ENABLED = True
+                output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
+                export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
+                self._onnx_transforms.append(CustomOpTransform)
+
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
@@ -252,15 +274,16 @@
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=constants.ONNX_EXPORT_OPSET,
+                do_constant_folding=True,
                 **export_kwargs,
             )
             logger.info("PyTorch export successful")
-
             _ = self._offload_model_weights(offload_pt_weights)
 
             model = onnx.load(tmp_onnx_path, load_external_data=False)
             transform_kwargs = {
                 "onnx_base_dir": str(tmp_onnx_dir),
+                "temp_onnx_path": tmp_onnx_path,
                 "model_name": self.model_name,
             }
             if onnx_transform_kwargs is not None:
@@ -284,6 +307,10 @@ def _export(
         finally:
             shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
 
+        if use_subfunctions:
+            undo_torch_patches()
+            InvalidIndexProvider.SUBFUNC_ENABLED = False
+
         self.onnx_path = onnx_path
         return onnx_path
 
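For context on the export_modules_as_functions flag wired in above: it tells the TorchScript ONNX exporter to emit each listed nn.Module subclass as a single reusable ONNX function instead of inlining its nodes, which is what get_decoder_layer_classes_for_export supplies for the decoder layers. A minimal standalone sketch (Block and Net are illustrative names, not from this repo; the flag requires opset 15 or newer):

import torch

class Block(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block() for _ in range(2))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

# Both Block calls become invocations of one shared ONNX FunctionProto.
torch.onnx.export(
    Net(),
    (torch.randn(1, 8),),
    "net.onnx",
    opset_version=17,
    export_modules_as_functions={Block},
)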
@@ -300,6 +327,7 @@ def _compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: Optional[bool] = False,
         qnn_config: Optional[str] = None,
+        use_subfunctions: bool = False,
         **compiler_options,
     ) -> str:
         """
@@ -325,9 +353,9 @@
 
         For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
         """
-        if onnx_path is None and self.onnx_path is None:
-            self.export()
 
+        if onnx_path is None and self.onnx_path is None:
+            self.export(use_subfunctions=use_subfunctions)
         onnx_path = Path(onnx_path or self.onnx_path)
         compile_dir = Path(compile_dir or onnx_path.parent)
         qpc_path = compile_dir / "qpc"
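A hedged end-to-end usage sketch of the new flag (this assumes the public CausalLM wrapper forwards use_subfunctions down to _export/_compile, which the remaining files in this commit appear to wire up; the model card and num_cores are illustrative):

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
# With use_subfunctions=True, export emits decoder layers as ONNX functions,
# applies the torch patches, and schedules CustomOpTransform before compiling.
qpc_path = model.compile(num_cores=16, use_subfunctions=True)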

QEfficient/base/onnx_transforms.py

Lines changed: 165 additions & 1 deletion
@@ -5,9 +5,12 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
+import onnx
+import onnxslim
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
 
@@ -99,3 +102,164 @@ def apply(
             current_file_size = tsize
             external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
         return model, transformed
+
+
+class OnnxSlimTransform(OnnxTransform):
+    """
+    Applies onnx-slim transformations on the given ONNX graph.
+    """
+
+    @classmethod
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        onnx_base_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        :param enable_onnx_slim_transform: If True, applies onnx-slim transformations (currently hard-coded on).
+        :param temp_onnx_path: Path to save the slimmed ONNX model.
+        """
+        transformed = False
+        onnx_slim_transform = True  # TODO: read kwargs.get("enable_onnx_slim_transform", False) instead of hard-coding
+        temp_onnx_path = kwargs.get("temp_onnx_path", None)
+        if not temp_onnx_path:
+            raise RuntimeError("temp_onnx_path is required for the onnx-slim transform.")
+        if onnx_slim_transform:
+            transformed = True
+            slimmed_model = onnxslim.slim(model)
+            onnx.save(slimmed_model, temp_onnx_path)
+            return slimmed_model, transformed
+        return model, transformed
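For reference, onnxslim.slim is the library's top-level entry point; a standalone sketch of the call used above (file names illustrative):

import onnx
import onnxslim

model = onnx.load("model.onnx")
# Returns a new ModelProto with constant folding, shape inference,
# and dead-node elimination applied.
slimmed = onnxslim.slim(model)
onnx.save(slimmed, "model.slim.onnx")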
138+
+
+class CustomOpTransform(OnnxTransform):
+    """
+    Transform to register custom operations and add their function protos to the ONNX model.
+    """
+
+    # Registry of custom operations: op_name -> (func_class, onnxscript_func)
+    _custom_ops: Dict[str, Tuple[Any, Any]] = {}
+
+    @classmethod
+    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
+        """Register a custom operation."""
+        cls._custom_ops[op_name] = (func_class, onnxscript_func)
+
+    @classmethod
+    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Apply custom op registration and add function protos to the model.
+
+        :param model: The ONNX model to transform
+        :param opset_version: ONNX opset version for symbolic registration
+        :returns: Transformed model and success flag
+        """
+        transformed = False
+
+        # Register all custom op symbolic functions with torch.onnx
+        for op_name, (func_class, _) in cls._custom_ops.items():
+            if hasattr(func_class, "symbolic"):
+                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
+
+        # Add function protos for custom ops that are used in the model
+        used_protos = cls._get_function_protos_for_model(model)
+
+        for proto in used_protos:
+            # Skip protos that already exist to avoid duplicates
+            if not any(func.name == proto.name for func in model.functions):
+                model.functions.append(proto)
+                transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
+        """Get function protos for custom ops that are actually used in the model."""
+        used_protos = []
+
+        # Collect all node op_types in the main graph
+        used_op_types = set()
+        for node in model.graph.node:
+            used_op_types.add(node.op_type)
+
+        # Also collect op_types from function bodies
+        for func in model.functions:
+            for node in func.node:
+                used_op_types.add(node.op_type)
+
+        # Emit a proto only for custom ops that are actually referenced
+        for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
+            if cls._is_custom_op_used(model, op_name, used_op_types):
+                used_protos.append(onnxscript_func.to_function_proto())
+
+        return used_protos
+
+    @classmethod
+    def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
+        """Check if a custom op is used in the model."""
+        # Direct match on node op_types
+        if op_name in used_op_types:
+            return True
+
+        # Domain-qualified ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
+        custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
+        if custom_op_pattern in used_op_types:
+            return True
+
+        # Heuristic checks based on op type
+        if "RMSNorm" in op_name:
+            # Any RMSNorm-related op present?
+            return any("RMSNorm" in op_type for op_type in used_op_types)
+
+        if "Ctx" in op_name:
+            # Gather/Scatter ops present (indicating KV-cache usage)?
+            return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
+
+        return False
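To illustrate the registration contract: each registry entry pairs a torch.autograd.Function, whose symbolic method is routed through torch.onnx.register_custom_op_symbolic, with an onnxscript function whose to_function_proto() supplies the ONNX body. A minimal sketch with hypothetical names (MySquareFunc and MySquare are illustrative, not QEfficient ops):

import onnxscript
import torch
from onnxscript import opset17 as ops
from QEfficient.base.onnx_transforms import CustomOpTransform

CUSTOM_OPSET = onnxscript.values.Opset(domain="com.example", version=1)

@onnxscript.script(CUSTOM_OPSET)
def MySquare(x):
    # ONNX body added to model.functions by CustomOpTransform.apply
    return ops.Mul(x, x)

class MySquareFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * x

    @staticmethod
    def symbolic(g, x):
        # Emitted as a custom-domain node during torch.onnx.export
        return g.op("com.example::MySquare", x)

CustomOpTransform.register_custom_op("MySquareFunc", MySquareFunc, MySquare)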
+
+
+class RenameFunctionOutputsTransform(OnnxTransform):
+    """
+    Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
+    """
+
+    @classmethod
+    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Rename function outputs in decoder layer nodes.
+
+        :param model: The ONNX model to transform
+        :returns: Transformed model and a boolean indicating whether the transform was applied
+        """
+        graph = model.graph
+        op_type_to_func_map = {func.name: func for func in model.functions}
+        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
+        transformed = False
+        model_graph_outputs = [val.name for val in model.graph.output]
+        layer_index = 0
+        for node in graph.node:
+            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
+                func = op_type_to_func_map.get(node.op_type)
+                if func is None:
+                    continue
+
+                for i, out_name in enumerate(func.output):
+                    if "_InternalRetainedState" in out_name:
+                        if "key" in out_name:
+                            new_name = f"past_key.{layer_index}_RetainedState"
+                        elif "value" in out_name:
+                            new_name = f"past_value.{layer_index}_RetainedState"
+                        else:
+                            # Guard: neither key nor value; leave this output untouched
+                            continue
+                        transformed = True
+                        tmp = node.output[i]
+                        node.output[i] = new_name
+                        # Update the graph output name if it exists
+                        if tmp in model_graph_outputs:
+                            model.graph.output[model_graph_outputs.index(tmp)].name = new_name
+                layer_index += 1
+        return model, transformed
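A short sketch of applying this transform in isolation (the path is illustrative):

import onnx
from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform

model = onnx.load("exported_with_subfunctions.onnx", load_external_data=False)
model, applied = RenameFunctionOutputsTransform.apply(model)
# applied is True when at least one "_InternalRetainedState" function output
# was renamed back to the flat "past_key.N_RetainedState" convention.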

QEfficient/transformers/cache_utils.py

Lines changed: 30 additions & 12 deletions
@@ -24,6 +24,33 @@
 )
 
 
+class InvalidIndexProvider:
+    SUBFUNC_ENABLED = False
+
+    @classmethod
+    def enable_subfunc(cls):
+        cls.SUBFUNC_ENABLED = True
+
+    @classmethod
+    def _get_invalid_idx_value(cls):
+        """
+        Get the appropriate invalid index value for CtxGather operations.
+
+        For ONNX export with functions, we use 0 to avoid INT32_MAX constants
+        that cause issues when functions are inlined at runtime.
+
+        Returns:
+            int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise)
+        """
+        if torch.onnx.is_in_onnx_export():
+            if cls.SUBFUNC_ENABLED:
+                return 0
+            else:
+                return torch.iinfo(torch.int32).max
+        else:
+            return 0
+
+
 class QEffDynamicLayer(DynamicLayer):
     def read_only(self, cache_kwargs):
         """
@@ -46,10 +73,7 @@ def read_only(self, cache_kwargs):
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
 
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
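To see what the provider feeds, here is the masking pattern in isolation (shapes reduced, values illustrative):

import torch

from QEfficient.transformers.cache_utils import InvalidIndexProvider

ctx_indices = torch.arange(8)[None, None, :]  # candidate cache positions
gather_limit = torch.tensor([[[3]]])          # max valid position id
invalid_mask = ctx_indices > gather_limit

# 0 in eager mode and for subfunction export; INT32_MAX for plain ONNX export
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
# Out-of-window positions now point at the invalid index before CtxGather runs.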

@@ -143,10 +167,7 @@ def update(
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
 
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
         if batch_index is not None:
@@ -419,10 +440,7 @@ def update(
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-        if torch.onnx.is_in_onnx_export():
-            invalid_idx_value = torch.iinfo(torch.int32).max
-        else:
-            invalid_idx_value = 0
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
         all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
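Note that the class exposes enable_subfunc() while _export sets the attribute directly; both routes flip the same class-level flag, as this sketch shows:

from QEfficient.transformers.cache_utils import InvalidIndexProvider

InvalidIndexProvider.enable_subfunc()         # helper route
assert InvalidIndexProvider.SUBFUNC_ENABLED
InvalidIndexProvider.SUBFUNC_ENABLED = False  # direct route used in _export's cleanup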
