     "phi_4_mini",
     "smollm2",
 ]
-TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_lora"]
 HUGGING_FACE_REPO_IDS = {
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
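Note: with the new flag in place, an adapter checkpoint can be passed alongside the base model. A minimal sketch of exercising it through the existing parser (the other flags already exist in export_llama; the placeholder paths and the pairing of llama3_2_lora with --adapter are assumptions about the intended workflow):

    # Hypothetical invocation; paths are placeholders.
    parser = build_args_parser()
    args = parser.parse_args([
        "--model", "llama3_2_lora",                      # new TORCHTUNE_DEFINED_MODELS entry
        "--checkpoint", "/path/to/consolidated.00.pth",  # base model weights
        "--adapter", "/path/to/adapter.pt",              # new argument added above
        "--params", "/path/to/params.json",
    ])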
@@ -585,17 +591,20 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

     # Convert dtype override string arg to actual type.
     dtype_override = DType[args.dtype_override]

+    # breakpoint()  # 1, OK.
     edge_manager = _load_llama_model(
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
-
     # At this point, the model is loaded in the default fp32.

     # Checkpoint dtype should be lower or equal precision to the dtype override.
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+    em1 = edge_manager.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(eager, em1)
+    # breakpoint()  # 4, OK.
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
     if not (
         checkpoint_dtype == dtype_override.to_torch_dtype()
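Note: the added lines probe the freshly loaded model with fixed tokens and compare against a saved eager reference; torch.allclose returns a bool, so the result is meant to be inspected at the breakpoints rather than asserted. A minimal sketch of how such a reference could be captured in the first place (the path matches the one loaded above; that it was produced exactly this way is an assumption):

    import torch

    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)  # same probe tokens as above
    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)   # same input positions
    eager_res = edge_manager.model.forward(eg, input_pos=ip)  # eager model, before any lowering
    torch.save(eager_res, "/data/users/lfq/executorch/eager_res.pt")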
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )

     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    # edge_manager.model = edge_manager.model.to(dtype=torch.float32)
+    em2 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em2, eager)
+    # breakpoint()  # 5, not OK, gets converted to bf16. OK if dtype is consistent.

     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
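Note: the "# 5" comment records that the comparison fails once the model is cast to bf16 while the saved reference stays in fp32, and passes again when the dtypes agree. A dtype-consistent, asserting variant of the same check might look like this (a sketch; the tolerances are illustrative, not measured):

    em2 = edge_manager.model.forward(eg, input_pos=ip)
    # Cast both sides to fp32 so the comparison does not depend on the dtype override.
    assert torch.allclose(
        em2.to(torch.float32), eager.to(torch.float32), rtol=1e-3, atol=1e-3
    ), "output diverged from the eager reference after the dtype override"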
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
     )
-
+    # torch.allclose here as well.
+    em3 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em3, eager)
     return edge_manager

@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama( # noqa: C901
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
+    breakpoint()
+    # ^to_edge_res.pt
+    # allclose 1e-1 compared to pre-auto.

     # to_backend
     partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    b_e = builder_exported.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(b_e, eager)
+    # breakpoint()
+
     builder_exported.run_canonical_optimizations()
+    b_e2 = builder_exported.model.forward(eg, input_pos=ip)
+    torch.allclose(b_e2, eager)
     modelname = builder_exported.modelname

     if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             args,
         )
     else:
+        # breakpoint()
+        b_e3 = builder_exported.model.forward(eg, input_pos=ip)
+        torch.allclose(b_e3, eager)
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             quant_dtype,
             args,
         )
+    breakpoint()

     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
         if modelname == "llama3_2_vision":
             module_name = "llama3_2_vision"
             model_class_name = "Llama3_2Decoder"
+        elif modelname == "llama3_2_lora":
+            module_name = "llama3_2_lora"
+            model_class_name = "Llama3_2_Lora"
         else:
             raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
         )
     )

+    # breakpoint()  # 3. OK.
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
             model.max_seq_len,
             # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
-            model.max_context_len,
+            max_context_len,
             # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    # transforms.append(
+    #     replace_rope_with_inference_rope()
+    # )
     return transforms
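Note: the probe tensors, eager_res.pt loads, and breakpoint() calls above are debugging scaffolding and would stop every export if merged as-is. One way to keep the checks while leaving the default path non-interactive is to fold them into a guarded helper, for example gated by an environment variable (a sketch under that assumption; the variable name and tolerances are placeholders):

    import os

    import torch

    def check_against_eager(model, eager_path="/data/users/lfq/executorch/eager_res.pt"):
        # Re-run the probe used throughout this change and compare against the saved
        # eager reference; skipped unless explicitly enabled.
        if os.environ.get("DEBUG_LORA_EXPORT") != "1":
            return
        eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
        ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
        out = model.forward(eg, input_pos=ip)
        eager = torch.load(eager_path)
        assert torch.allclose(
            out.to(torch.float32), eager.to(torch.float32), rtol=1e-3, atol=1e-3
        ), "model output diverged from the eager reference"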