Skip to content

Commit b540ea1

Browse files
committed
Incorporated changes suggested in comments
Signed-off-by: Dhiraj Kumar Sah <[email protected]>
1 parent 0481612 commit b540ea1

File tree

2 files changed

+39
-35
lines changed

2 files changed

+39
-35
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# ----------------------------------------------------------------------------
77

8-
# import hashlib
8+
import copy
99
import inspect
1010
import json
1111
import logging
@@ -52,10 +52,9 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
5252

5353
# Store Model parameters to Calculate Hash for caching
5454
self.model_params = {}
55-
self.model_params.update(kwargs)
55+
self.model_params = copy.deepcopy(kwargs)
5656
self.model_params["config"] = self.model.config.to_diff_dict()
5757
self.model_params["_transform_names"] = self._transform_names()
58-
self.compile_params = {}
5958

6059
if hasattr(self.model.config, "architectures"):
6160
self.model_architecture = self.model.config.architectures[0]
@@ -142,13 +141,15 @@ def _export(
142141
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
143142
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
144143
"""
145-
self.model_params["output_names"] = output_names
146-
self.model_params["dynamic_axes"] = dynamic_axes
144+
export_params = {}
145+
export_params["output_names"] = output_names
146+
export_params["dynamic_axes"] = dynamic_axes
147+
148+
self.model_params["export_params"] = export_params
149+
150+
self.model_params.update(export_kwargs) if export_kwargs is not None else None
151+
self.model_params.update(onnx_transform_kwargs) if export_kwargs is not None else None
147152

148-
if export_kwargs is not None:
149-
self.model_params.update(export_kwargs)
150-
if onnx_transform_kwargs is not None:
151-
self.model_params.update(onnx_transform_kwargs)
152153
export_dir = Path(export_dir or (QEFF_HOME / self.model_architecture / self.model_name))
153154

154155
export_hash = hash_dict_params(self.model_params)
@@ -163,17 +164,6 @@ def _export(
163164
tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
164165
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)
165166

166-
model_params_json = export_dir / "model_params.json"
167-
with open(model_params_json, "w") as fp:
168-
json.dump(
169-
{
170-
"model_params": [
171-
{k: make_serializable(self.model_params[k]) for k in sorted(self.model_params.keys())}
172-
]
173-
},
174-
fp,
175-
indent=4,
176-
)
177167
# Create input_names from example_inputs
178168

179169
input_names = []
@@ -231,6 +221,20 @@ def _export(
231221
onnx.save(model, onnx_path)
232222
logger.info("Transformed onnx saved")
233223

224+
# Dump the model parameters to a JSON file after a successful ONNX export
225+
model_params_json = export_dir / "model_params.json"
226+
with open(model_params_json, "w") as fp:
227+
json.dump(
228+
{
229+
"model_params": {
230+
k: make_serializable(self.model_params[k]) for k in sorted(self.model_params.keys())
231+
}
232+
},
233+
fp,
234+
indent=4,
235+
)
236+
logger.info("Parameters used for export hash dumped in a JSON file successfully")
237+
234238
except Exception as e:
235239
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")
236240

@@ -277,6 +281,8 @@ def _compile(
277281
if onnx_path is None and self.onnx_path is None:
278282
self.export()
279283

284+
self.compile_params = {}
285+
280286
onnx_path = Path(onnx_path or self.onnx_path)
281287
compile_dir = Path(compile_dir or onnx_path.parent)
282288
qpc_path = compile_dir / "qpc"
@@ -339,18 +345,6 @@ def _compile(
339345
# Probably compilation failure last time, delete directory to start over
340346
shutil.rmtree(qpc_path)
341347

342-
compile_params_json = compile_dir / "compile_params.json"
343-
with open(compile_params_json, "w") as fp:
344-
json.dump(
345-
{
346-
"compile_params": [
347-
{k: make_serializable(self.compile_params[k]) for k in sorted(self.compile_params.keys())}
348-
]
349-
},
350-
fp,
351-
indent=4,
352-
)
353-
354348
# Write specializations.json file
355349
if specializations is not None:
356350
specializations_json = compile_dir / "specializations.json"
@@ -394,6 +388,19 @@ def _compile(
394388
logger.info(f"Running compiler: {' '.join(command)}")
395389
try:
396390
subprocess.run(command, capture_output=True, check=True)
391+
392+
# Dump the compile parameters to a JSON file after a successful compilation
393+
compile_params_json = compile_dir / "compile_params.json"
394+
with open(compile_params_json, "w") as fp:
395+
json.dump(
396+
{
397+
"compile_params": {
398+
k: make_serializable(self.compile_params[k]) for k in sorted(self.compile_params.keys())
399+
}
400+
},
401+
fp,
402+
indent=4,
403+
)
397404
except subprocess.CalledProcessError as e:
398405
raise RuntimeError(
399406
"\n".join(

QEfficient/transformers/models/modeling_auto.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ def __init__(self, model: nn.Module, pooling=None, **kwargs):
171171
self.model.base_model.config.use_cache = True
172172
self.model_params["qeff_class"] = self.__class__.__name__
173173

174-
# self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
175-
176174
@classmethod
177175
@with_replaced_quantizers
178176
def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
@@ -913,7 +911,6 @@ def __init__(
913911
self.model.config.vision_config.use_flash_attn = "false"
914912
else:
915913
self.model.config.text_config.use_cache = True
916-
self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
917914

918915
@classmethod
919916
def from_pretrained(

0 commit comments

Comments
 (0)