55#
66# ----------------------------------------------------------------------------
77
8- # import hashlib
8+ import copy
99import inspect
1010import json
1111import logging
@@ -52,10 +52,9 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
5252
5353 # Store Model parameters to Calculate Hash for caching
5454 self .model_params = {}
55- self .model_params . update (kwargs )
55+ self .model_params = copy . deepcopy (kwargs )
5656 self .model_params ["config" ] = self .model .config .to_diff_dict ()
5757 self .model_params ["_transform_names" ] = self ._transform_names ()
58- self .compile_params = {}
5958
6059 if hasattr (self .model .config , "architectures" ):
6160 self .model_architecture = self .model .config .architectures [0 ]
@@ -142,13 +141,15 @@ def _export(
142141 :onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
143142 :export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
144143 """
145- self .model_params ["output_names" ] = output_names
146- self .model_params ["dynamic_axes" ] = dynamic_axes
144+ export_params = {}
145+ export_params ["output_names" ] = output_names
146+ export_params ["dynamic_axes" ] = dynamic_axes
147+
148+ self .model_params ["export_params" ] = export_params
149+
150+ self .model_params .update (export_kwargs ) if export_kwargs is not None else None
151+ self .model_params .update (onnx_transform_kwargs ) if export_kwargs is not None else None
147152
148- if export_kwargs is not None :
149- self .model_params .update (export_kwargs )
150- if onnx_transform_kwargs is not None :
151- self .model_params .update (onnx_transform_kwargs )
152153 export_dir = Path (export_dir or (QEFF_HOME / self .model_architecture / self .model_name ))
153154
154155 export_hash = hash_dict_params (self .model_params )
@@ -163,17 +164,6 @@ def _export(
163164 tmp_onnx_path = tmp_onnx_dir / f"{ self .model_name } .onnx"
164165 tmp_onnx_dir .mkdir (parents = True , exist_ok = True )
165166
166- model_params_json = export_dir / "model_params.json"
167- with open (model_params_json , "w" ) as fp :
168- json .dump (
169- {
170- "model_params" : [
171- {k : make_serializable (self .model_params [k ]) for k in sorted (self .model_params .keys ())}
172- ]
173- },
174- fp ,
175- indent = 4 ,
176- )
177167 # Create input_names from example_inputs
178168
179169 input_names = []
@@ -231,6 +221,20 @@ def _export(
231221 onnx .save (model , onnx_path )
232222 logger .info ("Transformed onnx saved" )
233223
224+ # Dumping model paramters in a JSON file after successful ONNX export
225+ model_params_json = export_dir / "model_params.json"
226+ with open (model_params_json , "w" ) as fp :
227+ json .dump (
228+ {
229+ "model_params" : {
230+ k : make_serializable (self .model_params [k ]) for k in sorted (self .model_params .keys ())
231+ }
232+ },
233+ fp ,
234+ indent = 4 ,
235+ )
236+ logger .info ("Parameters used for export hash dumped in a JSON file successfully" )
237+
234238 except Exception as e :
235239 logger .error (f"ONNX export (or) ONNXTransforms failed: { e } " )
236240
@@ -277,6 +281,8 @@ def _compile(
277281 if onnx_path is None and self .onnx_path is None :
278282 self .export ()
279283
284+ self .compile_params = {}
285+
280286 onnx_path = Path (onnx_path or self .onnx_path )
281287 compile_dir = Path (compile_dir or onnx_path .parent )
282288 qpc_path = compile_dir / "qpc"
@@ -339,18 +345,6 @@ def _compile(
339345 # Probably compilation failure last time, delete directory to start over
340346 shutil .rmtree (qpc_path )
341347
342- compile_params_json = compile_dir / "compile_params.json"
343- with open (compile_params_json , "w" ) as fp :
344- json .dump (
345- {
346- "compile_params" : [
347- {k : make_serializable (self .compile_params [k ]) for k in sorted (self .compile_params .keys ())}
348- ]
349- },
350- fp ,
351- indent = 4 ,
352- )
353-
354348 # Write specializations.json file
355349 if specializations is not None :
356350 specializations_json = compile_dir / "specializations.json"
@@ -394,6 +388,19 @@ def _compile(
394388 logger .info (f"Running compiler: { ' ' .join (command )} " )
395389 try :
396390 subprocess .run (command , capture_output = True , check = True )
391+
392+ # Dumping compile paramters in a JSON file after successful ONNX export
393+ compile_params_json = compile_dir / "compile_params.json"
394+ with open (compile_params_json , "w" ) as fp :
395+ json .dump (
396+ {
397+ "compile_params" : {
398+ k : make_serializable (self .compile_params [k ]) for k in sorted (self .compile_params .keys ())
399+ }
400+ },
401+ fp ,
402+ indent = 4 ,
403+ )
397404 except subprocess .CalledProcessError as e :
398405 raise RuntimeError (
399406 "\n " .join (
0 commit comments