55#
66# ----------------------------------------------------------------------------
77
8- from typing import Any , Dict , List , Optional , Tuple
8+ from typing import Any , Dict , List , Optional , Set , Tuple
99
1010import numpy as np
1111import torch
@@ -107,11 +107,11 @@ class CustomOpTransform(OnnxTransform):
107107 Transform to register custom operations and add their function protos to the ONNX model.
108108 """
109109
110- # Registry of custom operations
111- _custom_ops : Dict [str , Tuple [Any , Any ]] = {} # op_name -> (func_class, onnxscript_func)
110+ # Registry of custom operations: op_name -> (func_class, onnxscript_func)
111+ _custom_ops : Dict [str , Tuple [Any , Any ]] = {}
112112
113113 @classmethod
114- def register_custom_op (cls , op_name : str , func_class : Any , onnxscript_func : Any ):
114+ def register_custom_op (cls , op_name : str , func_class : Any , onnxscript_func : Any ) -> None :
115115 """Register a custom operation."""
116116 cls ._custom_ops [op_name ] = (func_class , onnxscript_func )
117117
@@ -120,9 +120,9 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple
120120 """
121121 Apply custom op registration and add function protos to the model.
122122
123- :param model: The ONNX model to transform
124- :param opset_version: ONNX opset version for symbolic registration
125- :returns: Transformed model and success flag
123+ :param model: The ONNX model to transform.
124+ :param opset_version: ONNX opset version for symbolic registration.
125+ :returns: (Transformed model, success flag).
126126 """
127127 transformed = False
128128
@@ -131,62 +131,70 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple
131131 if hasattr (func_class , "symbolic" ):
132132 torch .onnx .register_custom_op_symbolic (f"::{op_name}" , func_class .symbolic , opset_version )
133133
134- # Add function protos for custom ops that are used in the model
135- used_protos = cls ._get_function_protos_for_model (model )
134+ # Gather function names and all nodes (graph + function nodes)
135+ func_names : Set [str ] = {func .name for func in model .functions }
136+ all_nodes = list (model .graph .node )
137+ for func in model .functions :
138+ all_nodes .extend (func .node )
139+
140+ # Collect used op types
141+ used_op_types : Set [str ] = {node .op_type for node in all_nodes }
142+
143+ # Precompute heuristic flags
144+ has_rmsnorm = any ("RMSNorm" in op_type for op_type in used_op_types )
145+ has_ctx_ops = any (op_type in ["Gather" , "GatherND" , "Scatter" , "ScatterND" ] for op_type in used_op_types )
146+
147+ # Get function protos for custom ops used in the model
148+ used_protos = cls ._get_function_protos_for_model (used_op_types , has_rmsnorm , has_ctx_ops )
136149
150+ # Append new function protos if not already present
137151 for proto in used_protos :
138- # Check if proto already exists to avoid duplicates
139- proto_name = proto .name
140- if not any (func .name == proto_name for func in model .functions ):
152+ if proto .name not in func_names :
141153 model .functions .append (proto )
142154 transformed = True
143155
144156 return model , transformed
145157
146158 @classmethod
147- def _get_function_protos_for_model (cls , model : ModelProto ) -> List [Any ]:
148- """Get function protos for custom ops that are actually used in the model."""
149- used_protos = []
150-
151- # Get all node op_types in the model
152- used_op_types = set ()
153- for node in model .graph .node :
154- used_op_types .add (node .op_type )
155-
156- # Also check function calls
157- for func in model .functions :
158- for node in func .node :
159- used_op_types .add (node .op_type )
160-
161- # Check which custom ops are actually used
162- for op_name , (func_class , onnxscript_func ) in cls ._custom_ops .items ():
163- # Check if the custom op is referenced in the model
164- if cls ._is_custom_op_used (model , op_name , used_op_types ):
165- proto = onnxscript_func .to_function_proto ()
166- used_protos .append (proto )
159+ def _get_function_protos_for_model (cls , used_op_types : Set [str ], has_rmsnorm : bool , has_ctx_ops : bool ) -> List [Any ]:
160+ """
161+ Get function protos for custom ops that are actually used in the model.
167162
163+ :param used_op_types: Set of op types used in the model.
164+ :param has_rmsnorm: Flag indicating if RMSNorm-related ops are present.
165+ :param has_ctx_ops: Flag indicating if context-related ops are present.
166+ :returns: List of ONNX function protos.
167+ """
168+ used_protos : List [Any ] = []
169+ for op_name , (_ , onnxscript_func ) in cls ._custom_ops .items ():
170+ if cls ._is_custom_op_used (op_name , used_op_types , has_rmsnorm , has_ctx_ops ):
171+ used_protos .append (onnxscript_func .to_function_proto ())
168172 return used_protos
169173
170174 @classmethod
171- def _is_custom_op_used (cls , model : ModelProto , op_name : str , used_op_types : set ) -> bool :
172- """Check if a custom op is used in the model."""
173- # Check if the op_name appears in node op_types
175+ def _is_custom_op_used (cls , op_name : str , used_op_types : Set [str ], has_rmsnorm : bool , has_ctx_ops : bool ) -> bool :
176+ """
177+ Check if a custom op is used in the model.
178+
179+ :param op_name: Name of the custom op.
180+ :param used_op_types: Set of op types used in the model.
181+ :param has_rmsnorm: Precomputed RMSNorm presence flag.
182+ :param has_ctx_ops: Precomputed context ops presence flag.
183+ :returns: True if the custom op is used, False otherwise.
184+ """
174185 if op_name in used_op_types :
175186 return True
176187
177- # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
178- custom_op_pattern = f"com.qti.aisw.onnx::{ op_name .replace ('Func' , '' )} "
179- if custom_op_pattern in used_op_types :
188+ # Check for domain-specific ops
189+ if f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" in used_op_types :
180190 return True
181191
182- # Heuristic checks based on op type
183- if "RMSNorm" in op_name :
184- # Check if any RMSNorm-related ops are present
185- return any ("RMSNorm" in op_type for op_type in used_op_types )
192+ # Heuristic checks
193+ if "RMSNorm" in op_name and has_rmsnorm :
194+ return True
186195
187- if "Ctx" in op_name :
188- # Check if Gather/Scatter operations are present (indicating KV cache usage)
189- return any (op_type in ["Gather" , "GatherND" , "Scatter" , "ScatterND" ] for op_type in used_op_types )
196+ if "Ctx" in op_name and has_ctx_ops :
197+ return True
190198
191199 return False
192200
@@ -208,7 +216,10 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
208216 op_type_to_func_map = {func .name : func for func in model .functions }
209217 decoder_layer_patterns = ["DecoderLayer" , "Block" , "Layer" ]
210218 transformed = False
211- model_graph_outputs = [val .name for val in model .graph .output ]
219+
220+ # Create a dict mapping output name to its index for quick lookup
221+ model_graph_outputs_map = {val .name : idx for idx , val in enumerate (model .graph .output )}
222+
212223 layer_index = 0
213224 for node in graph .node :
214225 if any (pattern in node .name or pattern in node .op_type for pattern in decoder_layer_patterns ):
@@ -219,14 +230,18 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
219230 for i , out_name in enumerate (func .output ):
220231 if "_InternalRetainedState" in out_name :
221232 transformed = True
222- tmp = node .output [i ]
233+ original_output_name = node .output [i ]
234+
235+ # Generate new name based on key/value
223236 if "key" in out_name :
224237 new_name = f"past_key.{layer_index}_RetainedState"
225238 elif "value" in out_name :
226239 new_name = f"past_value.{layer_index}_RetainedState"
227240 node .output [i ] = new_name
241+
228242 # Update graph output name if it exists
229- if tmp in model_graph_outputs :
230- model .graph .output [model_graph_outputs .index (tmp )].name = new_name
231- layer_index = layer_index + 1
243+ if original_output_name in model_graph_outputs_map :
244+ idx = model_graph_outputs_map [original_output_name ]
245+ model .graph .output [idx ].name = new_name
246+ layer_index += 1
232247 return model , transformed
0 commit comments