@@ -313,7 +313,7 @@ def export_to_edge(
313313 core_aten_ep , edge_constant_methods , edge_compile_config , verbose = verbose
314314 )
315315
316- def export_for_et (model , device , output_path ) -> str :
316+ def export_for_et (model , device , output_path , edge_constant_methods ) -> str :
317317
318318 input = (
319319 torch .tensor ([[1 ]], dtype = torch .long , device = device ),
@@ -344,12 +344,15 @@ def export_for_et(model, device, output_path) -> str:
344344 with torch .nn .attention .sdpa_kernel (
345345 [torch .nn .attention .SDPBackend .MATH ]
346346 ), torch .no_grad ():
347- m = export_for_training (model , input , dynamic_shapes = dynamic_shapes ).module ()
347+ m = export_for_training (
348+ model , input , dynamic_shapes = dynamic_shapes
349+ ).module ()
348350
349351 edge_manager = export_to_edge (
350352 m ,
351353 input ,
352354 dynamic_shapes = dynamic_shapes ,
355+ edge_constant_methods = edge_constant_methods ,
353356 edge_compile_config = edge_config ,
354357 )
355358 edge_manager = edge_manager .to_backend (XnnpackDynamicallyQuantizedPartitioner ())
@@ -365,6 +368,7 @@ def export_for_et(model, device, output_path) -> str:
365368 )
366369
367370 print ("The methods are: " , export_program .methods )
371+ print ("The config methods are: " , export_program .config_methods )
368372 with open (output_path , "wb" ) as f :
369373 export_program .write_to_file (f )
370374
@@ -407,7 +411,9 @@ def main(args):
407411 f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={ builder_args .device } setting."
408412 )
409413 builder_args .device = "cpu"
410- elif (output_pte_path or output_dso_path or output_aoti_package_path ) and "mps" in builder_args .device :
414+ elif (
415+ output_pte_path or output_dso_path or output_aoti_package_path
416+ ) and "mps" in builder_args .device :
411417 print ("Warning! Device MPS not supported for export. Exporting for device CPU." )
412418 builder_args .device = "cpu"
413419
@@ -473,13 +479,26 @@ def main(args):
473479 support_tensor_subclass = False ,
474480 )
475481 _unset_gguf_kwargs (builder_args )
476-
482+
483+ if tokenizer_args is None :
484+ tokenizer_type = "0"
485+ elif tokenizer_args .is_sentencepiece :
486+ tokenizer_type = "2" # Corresponding to llama2
487+ else :
488+ tokenizer_type = "3" # Corresponding to llama3
489+
477490 with torch .no_grad ():
478491 if output_pte_path :
479492 output_pte_path = str (os .path .abspath (output_pte_path ))
480493 if executorch_export_available :
481494 print (f"Exporting model using ExecuTorch to { output_pte_path } " )
482- export_for_et (model_to_pte , builder_args .device , args .output_pte_path )
495+ print (f"Tokenizer type is { tokenizer_type } " )
496+ export_for_et (
497+ model_to_pte ,
498+ builder_args .device ,
499+ args .output_pte_path ,
500+ {"tokenizer_type" : int (tokenizer_type )},
501+ )
483502 else :
484503 print (
485504 "Export with executorch requested but ExecuTorch could not be loaded"
@@ -503,13 +522,6 @@ def main(args):
503522 if output_aoti_package_path :
504523 output_aoti_package_path = str (os .path .abspath (output_aoti_package_path ))
505524
506- if tokenizer_args is None :
507- tokenizer_type = "0"
508- elif tokenizer_args .is_sentencepiece :
509- tokenizer_type = "2" # Corresponding to llama2
510- else :
511- tokenizer_type = "3" # Corresponding to llama3
512-
513525 metadata = {"tokenizer_type" : tokenizer_type }
514526 print (
515527 "Exporting model using AOT Inductor to " f"{ output_aoti_package_path } ."
0 commit comments