Commit 663f897

Added the latest change for subfunction
1 parent 6bc5256 commit 663f897

File tree: 9 files changed, +23084 -3804 lines


QEfficient/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,7 @@
 from QEfficient.utils import custom_format_warning
 from QEfficient.utils.logging_utils import logger

+from QEfficient.utils.patches import apply_torch_patches, is_patched
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
@@ -21,6 +22,9 @@
 # custom warning for the better logging experience
 warnings.formatwarning = custom_format_warning

+# Apply patches
+# TODO: Find a better way to do this, this is temp. fix.
+apply_torch_patches()

 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
@@ -69,6 +73,8 @@ def check_qaic_sdk():
         "QEFFAutoModelForImageTextToText",
         "QEFFAutoModelForSpeechSeq2Seq",
         "QEFFCommonLoader",
+        "apply_torch_patches",
+        "is_patched",
     ]

 else:
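
Since apply_torch_patches() runs as an import side effect, downstream code only needs to import the package. A minimal sketch of how a caller might confirm the patch took effect (assumes QEfficient is installed):

    import QEfficient

    # is_patched() is exported by this commit; it compares the live
    # torch.onnx.utils._setup_trace_module_map against the patched function.
    assert QEfficient.is_patched(), "apply_torch_patches() did not take effect"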

QEfficient/base/modeling_qeff.py

Lines changed: 1 addition & 18 deletions
@@ -253,22 +253,6 @@ def _export(
             decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
             export_kwargs = {} if export_kwargs is None else export_kwargs

-            # def sanitize_decoder_layer_for_onnx(module):
-            #     """Remove or simplify attributes that ONNX export cannot handle."""
-            #     unsafe_attrs = ["config", "experts", "router", "cache", "past_key_values", "sliding_window"]
-            #     for attr in unsafe_attrs:
-            #         if hasattr(module, attr):
-            #             try:
-            #                 setattr(module, attr, None)
-            #             except Exception:
-            #                 pass
-
-            # # Sanitize *only* the decoder layers
-            # for m in self.model.modules():
-            #     if m.__class__ in decoder_layer_classes:
-            #         sanitize_decoder_layer_for_onnx(m)
-
-            # import pdb; pdb.set_trace()
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
@@ -285,9 +269,8 @@ def _export(
             logger.info("PyTorch export successful")

             _ = self._offload_model_weights(offload_pt_weights)
-
-            rename_function_outputs(tmp_onnx_path, output_names)
             model = onnx.load(tmp_onnx_path, load_external_data=False)
+            model, transformed = rename_function_outputs(model)
             transform_kwargs = {
                 "onnx_base_dir": str(tmp_onnx_dir),
                 "temp_onnx_path": tmp_onnx_path,

QEfficient/base/onnx_transforms.py

Lines changed: 15 additions & 43 deletions
@@ -223,50 +223,22 @@ def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set)

         return False

-
-def rename_function_outputs(onnx_path, expected_output_names):
-    model = onnx.load(onnx_path, load_external_data=False)
+def rename_function_outputs(model):
     graph = model.graph
-    for i, output in enumerate(graph.output):
-        output.name = expected_output_names[i]
-
+    op_type_to_func_map = {func.name: func for func in model.functions}
     decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
-    layer_index = 0
-    output_rename_map = {}
-
+    transformed = False
+    model_graph_outputs = [val.name for val in model.graph.output]
     for node in graph.node:
         if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
-            if "layers.0" in node.name:
-                if len(node.output) >= 4:
-                    print(f"Renaming outputs of node (layers.0): {node.name}")
-                    new_output_0 = f"past_key.{layer_index}_RetainedState"
-                    new_output_1 = f"past_value.{layer_index}_RetainedState"
-                    output_rename_map[node.output[2]] = new_output_0
-                    output_rename_map[node.output[3]] = new_output_1
-                    node.output[2] = new_output_0
-                    node.output[3] = new_output_1
-                    layer_index += 1
-                else:
-                    print(f"Warning: Node {node.name} has fewer than 4 outputs.")
-            elif len(node.output) >= 2:
-                print(f"Renaming outputs of node: {node.name}")
-                new_output_0 = f"past_key.{layer_index}_RetainedState"
-                new_output_1 = f"past_value.{layer_index}_RetainedState"
-                output_rename_map[node.output[0]] = new_output_0
-                output_rename_map[node.output[1]] = new_output_1
-                node.output[0] = new_output_0
-                node.output[1] = new_output_1
-                layer_index += 1
-            else:
-                print(f"Warning: Node {node.name} has fewer than 2 outputs.")
-
-    for node in graph.node:
-        for i, input_name in enumerate(node.input):
-            if input_name in output_rename_map:
-                import pdb
-
-                pdb.set_trace()
-                print(f"Replacing input {input_name} in node {node.name} with {output_rename_map[input_name]}")
-                node.input[i] = output_rename_map[input_name]
-
-    onnx.save(model, onnx_path)
+            func = op_type_to_func_map[node.op_type]
+            for i, out_name in enumerate(func.output):
+                if "_InternalRetainedState" in out_name:
+                    transformed = True
+                    tmp = node.output[i]
+                    new_name = func.output[i].replace("Internal", "")
+                    print(f"renaming {node.output[i]} to {new_name}")
+                    node.output[i] = new_name
+                    model.graph.output[model_graph_outputs.index(tmp)].name = new_name
+
+    return model, transformed
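
The rewritten pass looks up the ONNX function behind each decoder-layer node and drops the "Internal" tag from retained-state outputs, then fixes up the matching graph output by name. A minimal illustration of the naming rule alone (the output name is hypothetical):

    # Illustrative only: a function-local KV-cache output carrying the
    # "_InternalRetainedState" suffix becomes the external retained-state name.
    internal = "past_key.0_InternalRetainedState"  # hypothetical output name
    external = internal.replace("Internal", "")
    assert external == "past_key.0_RetainedState"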

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def forward(self, hidden_states):
         up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1)

         # Apply activation with clamping
-        gate = gate.clamp(min=None, max=self.experts.limit)
+        gate = gate.clamp(min=-self.experts.limit, max=self.experts.limit)
         up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

         # GLU activation
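
The one-line change matters because clamp(min=None, ...) bounds a tensor only from above, leaving the gate activation unbounded below. A quick check of the behavioral difference (the limit value is illustrative, not taken from the model config):

    import torch

    limit = 7.0  # illustrative; the real bound comes from self.experts.limit
    gate = torch.tensor([-100.0, 0.0, 100.0])
    print(gate.clamp(min=None, max=limit))    # tensor([-100., 0., 7.])
    print(gate.clamp(min=-limit, max=limit))  # tensor([-7., 0., 7.])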

QEfficient/utils/patches.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""Monkey patches for torch.onnx.utils to fix ONNX export issues."""
+
+from typing import Collection, Set, Type, Union
+
+import torch
+import torch.onnx.utils as onnx_utils
+from torch import _C
+
+
+def _setup_trace_module_map_patched(
+    model: Union[torch.nn.Module, torch.jit.ScriptModule],
+    export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
+) -> Set[str]:
+    """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch."""
+
+    def __register_attribute_hook():
+        attr_name = "_onnx_attrs"
+
+        def _track_module_attributes_forward_pre_hook(module, input):
+            setattr(module, attr_name, _get_module_attributes(module))
+
+        def _track_module_attributes_forward_hook(module, input, output):
+            tracing_state = _C._get_tracing_state()
+            if not tracing_state:
+                return
+            graph = tracing_state.graph()
+            onnx_attrs = {}
+            if hasattr(module, attr_name):
+                onnx_attrs = getattr(module, attr_name)
+                delattr(module, attr_name)
+            # FIX: use empty dict to avoid type mismatch with _jit_pass_onnx_track_scope_attributes
+            # Observed in transformers v4.55 and above
+            onnx_attrs = {}
+            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
+
+        for m in model.modules():
+            m.register_forward_hook(_track_module_attributes_forward_hook)
+            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
+
+    def _unqualified_variable_name(qualified_name: str) -> str:
+        """
+        Parse qualified variable name and return the unqualified version.
+        Pure numeric atoms are considered inadequate, so this function will look past them,
+        and start from the first non-numeric atom.
+        """
+        name_atoms = qualified_name.split(".")
+        for i, atom in reversed(list(enumerate(name_atoms))):
+            if not atom.isnumeric():
+                return ".".join(name_atoms[i:])
+        return qualified_name
+
+    trace_module_map = {
+        _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n))
+        for _n, _m in model.named_modules()
+    }
+    torch.jit._trace._trace_module_map = trace_module_map
+
+    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
+        module_typenames = {torch.typename(type(module)) for module in trace_module_map}
+    elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
+
+        def _find_typename(v):
+            if isinstance(v, type):
+                return torch.typename(v)
+            else:
+                raise RuntimeError(
+                    "Only type of the `nn.Module` should be "
+                    "passed in the set for argument `export_modules_as_functions`. "
+                    f"Got `{type(v).__name__}`."
+                )
+
+        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
+    else:
+        module_typenames = set()
+
+    if module_typenames:
+        __register_attribute_hook()
+
+    return module_typenames
+
+
+def _get_module_attributes(module):
+    """Helper function to get module attributes safely."""
+    import typing
+
+    annotations = typing.get_type_hints(type(module))
+    base_m_annotations = typing.get_type_hints(torch.nn.Module)
+    [annotations.pop(k, None) for k in base_m_annotations]
+
+    attrs = {}
+    for k in annotations:
+        try:
+            attrs[k] = getattr(module, k)
+        except AttributeError:
+            _C._jit_onnx_log(f"Skipping module attribute '{k}'")
+            continue
+    return attrs
+
+
+def apply_torch_patches():
+    """Apply all necessary torch patches for ONNX export."""
+    # Monkey patch the function
+    onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched
+
+    if hasattr(onnx_utils, "_get_module_attributes"):
+        onnx_utils._get_module_attributes = _get_module_attributes
+
+    print("Applied torch ONNX export patches for export_modules_as_functions compatibility")
+
+
+def is_patched():
+    """Check if patches have been applied."""
+    return onnx_utils._setup_trace_module_map == _setup_trace_module_map_patched
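
Because only _setup_trace_module_map and _get_module_attributes are swapped out, the patch is scoped to exports that pass export_modules_as_functions. A hedged usage sketch (MyDecoderLayer is a stand-in for whatever module class gets exported as an ONNX function):

    from QEfficient.utils.patches import apply_torch_patches, is_patched

    apply_torch_patches()
    assert is_patched()
    # torch.onnx.export(model, (example_inputs,), "model.onnx",
    #                   export_modules_as_functions={MyDecoderLayer})  # stand-in class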

output.txt

Lines changed: 22935 additions & 3736 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -39,11 +39,11 @@ dependencies = [
     "fire",
     "py7zr",
     "torchmetrics==1.7.0",
-    "torch==2.7.0; platform_machine=='aarch64'",
+    "torch==2.4.1; platform_machine=='aarch64'",
     # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
     "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
-    "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
-    "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
+    "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
+    "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
 ]

 [project.optional-dependencies]
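
A quick sanity check that an environment picked up the reverted torch pin (an assumption about the target environment, not part of the commit):

    import torch

    # Expect the pinned 2.4.1 CPU build; fail loudly otherwise.
    assert torch.__version__.startswith("2.4.1"), torch.__version__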

run.py

Lines changed: 1 addition & 2 deletions
@@ -84,7 +84,7 @@
 # ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=False)

 qpc_path = qeff_model.compile(
-    prefill_seq_len=Constants.PROMPT_LEN,
+    prefill_seq_len=1,
     ctx_len=Constants.CTX_LEN,
     num_cores=16,
     mxfp6_matmul=False,
@@ -98,7 +98,6 @@
 streamer = TextStreamer(tokenizer)
 exec_info = qeff_model.generate(
     tokenizer,
-    streamer=streamer,
     prompts=Constants.INPUT_STR,
     device_ids=[0, 1, 2, 3],
 )

test.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
 from QEfficient import QEFFAutoModelForCausalLM
 import torch
 # Initialize the model using from_pretrained similar to transformers.AutoModelForCausalLM
-model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+model_name = "meta-llama/Llama-3.2-1B"
 # model_name="GPT2"
 # model_name="Qwen/Qwen2-1.5B-Instruct"
 import time
@@ -14,6 +14,7 @@
 # print("torch.compile run for model.model")
 # print("time ",t2-t1)
 # print("done")
+# import pdb; pdb.set_trace()
 inputs="Help me with this"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 # tokens=tokenizer([input], return_tensors="pt")
