Skip to content

Commit 6daa209

Browse files
Addressed all the comments
Signed-off-by: abhishek-singh591 <[email protected]>
1 parent 219230a commit 6daa209

File tree

6 files changed

+115
-63
lines changed

6 files changed

+115
-63
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import onnx
2020
import torch
2121

22-
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
22+
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform
2323
from QEfficient.base.pytorch_transforms import PytorchTransform
2424
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2525
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
@@ -37,7 +37,7 @@
3737
hash_dict_params,
3838
load_json,
3939
)
40-
from QEfficient.utils.patches import apply_torch_patches, undo_torch_patches
40+
from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
4141

4242
logger = logging.getLogger(__name__)
4343

@@ -59,7 +59,7 @@ class QEFFBaseModel(ABC):
5959
def _transform_names(cls) -> List[str]:
6060
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
6161

62-
def __init__(self, model: torch.nn.Module, use_onnx_subfunctions: bool = False, **kwargs) -> None:
62+
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
6363
super().__init__()
6464
self.model = model
6565
self.hash_params = create_model_params(self, **kwargs)
@@ -70,7 +70,6 @@ def __init__(self, model: torch.nn.Module, use_onnx_subfunctions: bool = False,
7070
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
7171
) or None
7272

73-
self.use_onnx_subfunctions = use_onnx_subfunctions
7473
# Flag for checking if weights are offloaded
7574
self._is_weights_offloaded: bool = False
7675

@@ -264,8 +263,10 @@ def _export(
264263
InvalidIndexProvider.SUBFUNC_ENABLED = True
265264
output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
266265
export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
266+
self._onnx_transforms.append(RenameFunctionOutputsTransform)
267267
self._onnx_transforms.append(CustomOpTransform)
268268

269+
# import pdb; pdb.set_trace()
269270
torch.onnx.export(
270271
self.model,
271272
(example_inputs,),
@@ -274,7 +275,6 @@ def _export(
274275
output_names=output_names,
275276
dynamic_axes=dynamic_axes,
276277
opset_version=constants.ONNX_EXPORT_OPSET,
277-
do_constant_folding=True,
278278
**export_kwargs,
279279
)
280280
logger.info("PyTorch export successful")
@@ -283,7 +283,6 @@ def _export(
283283
model = onnx.load(tmp_onnx_path, load_external_data=False)
284284
transform_kwargs = {
285285
"onnx_base_dir": str(tmp_onnx_dir),
286-
"temp_onnx_path": tmp_onnx_path,
287286
"model_name": self.model_name,
288287
}
289288
if onnx_transform_kwargs is not None:
@@ -310,6 +309,8 @@ def _export(
310309
if use_onnx_subfunctions:
311310
undo_torch_patches()
312311
InvalidIndexProvider.SUBFUNC_ENABLED = False
312+
self._onnx_transforms.pop()
313+
self._onnx_transforms.pop()
313314

314315
self.onnx_path = onnx_path
315316
return onnx_path
@@ -356,6 +357,7 @@ def _compile(
356357

357358
if onnx_path is None and self.onnx_path is None:
358359
self.export(use_onnx_subfunctions=use_onnx_subfunctions)
360+
359361
onnx_path = Path(onnx_path or self.onnx_path)
360362
compile_dir = Path(compile_dir or onnx_path.parent)
361363
qpc_path = compile_dir / "qpc"

QEfficient/base/onnx_transforms.py

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# ----------------------------------------------------------------------------
77

8-
from typing import Any, Dict, List, Optional, Tuple
8+
from typing import Any, Dict, List, Optional, Set, Tuple
99

1010
import numpy as np
1111
import torch
@@ -107,11 +107,11 @@ class CustomOpTransform(OnnxTransform):
107107
Transform to register custom operations and add their function protos to the ONNX model.
108108
"""
109109

110-
# Registry of custom operations
111-
_custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func)
110+
# Registry of custom operations: op_name -> (func_class, onnxscript_func)
111+
_custom_ops: Dict[str, Tuple[Any, Any]] = {}
112112

113113
@classmethod
114-
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
114+
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None:
115115
"""Register a custom operation."""
116116
cls._custom_ops[op_name] = (func_class, onnxscript_func)
117117

@@ -120,9 +120,9 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple
120120
"""
121121
Apply custom op registration and add function protos to the model.
122122
123-
:param model: The ONNX model to transform
124-
:param opset_version: ONNX opset version for symbolic registration
125-
:returns: Transformed model and success flag
123+
:param model: The ONNX model to transform.
124+
:param opset_version: ONNX opset version for symbolic registration.
125+
:returns: (Transformed model, success flag).
126126
"""
127127
transformed = False
128128

@@ -131,62 +131,70 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple
131131
if hasattr(func_class, "symbolic"):
132132
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
133133

134-
# Add function protos for custom ops that are used in the model
135-
used_protos = cls._get_function_protos_for_model(model)
134+
# Gather function names and all nodes (graph + function nodes)
135+
func_names: Set[str] = {func.name for func in model.functions}
136+
all_nodes = list(model.graph.node)
137+
for func in model.functions:
138+
all_nodes.extend(func.node)
139+
140+
# Collect used op types
141+
used_op_types: Set[str] = {node.op_type for node in all_nodes}
142+
143+
# Precompute heuristic flags
144+
has_rmsnorm = any("RMSNorm" in op_type for op_type in used_op_types)
145+
has_ctx_ops = any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
146+
147+
# Get function protos for custom ops used in the model
148+
used_protos = cls._get_function_protos_for_model(used_op_types, has_rmsnorm, has_ctx_ops)
136149

150+
# Append new function protos if not already present
137151
for proto in used_protos:
138-
# Check if proto already exists to avoid duplicates
139-
proto_name = proto.name
140-
if not any(func.name == proto_name for func in model.functions):
152+
if proto.name not in func_names:
141153
model.functions.append(proto)
142154
transformed = True
143155

144156
return model, transformed
145157

146158
@classmethod
147-
def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
148-
"""Get function protos for custom ops that are actually used in the model."""
149-
used_protos = []
150-
151-
# Get all node op_types in the model
152-
used_op_types = set()
153-
for node in model.graph.node:
154-
used_op_types.add(node.op_type)
155-
156-
# Also check function calls
157-
for func in model.functions:
158-
for node in func.node:
159-
used_op_types.add(node.op_type)
160-
161-
# Check which custom ops are actually used
162-
for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
163-
# Check if the custom op is referenced in the model
164-
if cls._is_custom_op_used(model, op_name, used_op_types):
165-
proto = onnxscript_func.to_function_proto()
166-
used_protos.append(proto)
159+
def _get_function_protos_for_model(cls, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> List[Any]:
160+
"""
161+
Get function protos for custom ops that are actually used in the model.
167162
163+
:param used_op_types: Set of op types used in the model.
164+
:param has_rmsnorm: Flag indicating if RMSNorm-related ops are present.
165+
:param has_ctx_ops: Flag indicating if context-related ops are present.
166+
:returns: List of ONNX function protos.
167+
"""
168+
used_protos: List[Any] = []
169+
for op_name, (_, onnxscript_func) in cls._custom_ops.items():
170+
if cls._is_custom_op_used(op_name, used_op_types, has_rmsnorm, has_ctx_ops):
171+
used_protos.append(onnxscript_func.to_function_proto())
168172
return used_protos
169173

170174
@classmethod
171-
def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
172-
"""Check if a custom op is used in the model."""
173-
# Check if the op_name appears in node op_types
175+
def _is_custom_op_used(cls, op_name: str, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> bool:
176+
"""
177+
Check if a custom op is used in the model.
178+
179+
:param op_name: Name of the custom op.
180+
:param used_op_types: Set of op types used in the model.
181+
:param has_rmsnorm: Precomputed RMSNorm presence flag.
182+
:param has_ctx_ops: Precomputed context ops presence flag.
183+
:returns: True if the custom op is used, False otherwise.
184+
"""
174185
if op_name in used_op_types:
175186
return True
176187

177-
# Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
178-
custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
179-
if custom_op_pattern in used_op_types:
188+
# Check for domain-specific ops
189+
if f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" in used_op_types:
180190
return True
181191

182-
# Heuristic checks based on op type
183-
if "RMSNorm" in op_name:
184-
# Check if any RMSNorm-related ops are present
185-
return any("RMSNorm" in op_type for op_type in used_op_types)
192+
# Heuristic checks
193+
if "RMSNorm" in op_name and has_rmsnorm:
194+
return True
186195

187-
if "Ctx" in op_name:
188-
# Check if Gather/Scatter operations are present (indicating KV cache usage)
189-
return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
196+
if "Ctx" in op_name and has_ctx_ops:
197+
return True
190198

191199
return False
192200

@@ -208,7 +216,10 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
208216
op_type_to_func_map = {func.name: func for func in model.functions}
209217
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
210218
transformed = False
211-
model_graph_outputs = [val.name for val in model.graph.output]
219+
220+
# Create a dict mapping output name to its index for quick lookup
221+
model_graph_outputs_map = {val.name: idx for idx, val in enumerate(model.graph.output)}
222+
212223
layer_index = 0
213224
for node in graph.node:
214225
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
@@ -219,14 +230,18 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
219230
for i, out_name in enumerate(func.output):
220231
if "_InternalRetainedState" in out_name:
221232
transformed = True
222-
tmp = node.output[i]
233+
original_output_name = node.output[i]
234+
235+
# Generate new name based on key/value
223236
if "key" in out_name:
224237
new_name = f"past_key.{layer_index}_RetainedState"
225238
elif "value" in out_name:
226239
new_name = f"past_value.{layer_index}_RetainedState"
227240
node.output[i] = new_name
241+
228242
# Update graph output name if it exists
229-
if tmp in model_graph_outputs:
230-
model.graph.output[model_graph_outputs.index(tmp)].name = new_name
231-
layer_index = layer_index + 1
243+
if original_output_name in model_graph_outputs_map:
244+
idx = model_graph_outputs_map[original_output_name]
245+
model.graph.output[idx].name = new_name
246+
layer_index += 1
232247
return model, transformed

0 commit comments

Comments
 (0)