Commit 1c7e59f

torchtune lora experiment
1 parent 644b7dd commit 1c7e59f

11 files changed (+558 -9 lines)

examples/models/llama/export_llama_lib.py (+50 -4)

@@ -101,7 +101,7 @@
     "phi_4_mini",
     "smollm2",
 ]
-TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_lora"]
 HUGGING_FACE_REPO_IDS = {
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
@@ -585,17 +591,20 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
     # Convert dtype override string arg to actual type.
     dtype_override = DType[args.dtype_override]
 
+    # breakpoint() # 1, OK.
     edge_manager = _load_llama_model(
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
-
     # At this point, the model is loaded in the default fp32.
 
     # Checkpoint dtype should be lower or equal precision to the dtype override.
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+    em1 = edge_manager.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(eager, em1)
+    # breakpoint() # 4, OK.
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
     if not (
         checkpoint_dtype == dtype_override.to_torch_dtype()
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
 
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    # edge_manager.model = edge_manager.model.to(dtype=torch.float32)
+    em2 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em2, eager)
+    # breakpoint() # 5, not OK, gets converted to bf16. OK if dtype is consistent.
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
     )
-
+    # torch.allclose here as well.
+    em3 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em3, eager)
     return edge_manager
 
 
@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama( # noqa: C901
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
+    breakpoint()
+    # ^to_edge_res.pt
+    # allclose 1e-1 compared to pre-auto.
 
     # to_backend
     partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
 
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    b_e = builder_exported.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(b_e, eager)
+    # breakpoint()
+
     builder_exported.run_canonical_optimizations()
+    b_e2 = builder_exported.model.forward(eg, input_pos=ip)
+    torch.allclose(b_e2, eager)
     modelname = builder_exported.modelname
 
     if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             args,
         )
     else:
+        # breakpoint()
+        b_e3 = builder_exported.model.forward(eg, input_pos=ip)
+        torch.allclose(b_e3, eager)
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             quant_dtype,
             args,
         )
+        breakpoint()
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
         if modelname == "llama3_2_vision":
             module_name = "llama3_2_vision"
             model_class_name = "Llama3_2Decoder"
+        if modelname == "llama3_2_lora":
+            module_name = "llama3_2_lora"
+            model_class_name = "Llama3_2_Lora"
         else:
             raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
         )
     )
 
+    # breakpoint() # 3. OK.
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
             model.max_seq_len,
             # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
-            model.max_context_len,
+            max_context_len,
            # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    # transforms.append(
+    #     replace_rope_with_inference_rope()
+    # )
     return transforms
 

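The stage-by-stage checks above all repeat one pattern: fixed example tokens and positions, a saved eager reference, and a torch.allclose comparison. A minimal helper sketch of that pattern; the helper name, default path, and tolerance are illustrative and not part of this commit, and it assumes the staged model still accepts (tokens, input_pos=...) like the debug calls above:

import torch

def check_against_eager(model, eager_path="eager_res.pt", atol=1e-1):
    # Sketch only: mirrors the ad-hoc checks in this commit.
    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
    eager = torch.load(eager_path)  # reference output saved from the eager model
    staged = model.forward(eg, input_pos=ip)
    return torch.allclose(staged, eager, atol=atol)
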
examples/models/llama/source_transformation/sdpa.py (+10)

@@ -15,6 +15,8 @@
 
 from executorch.examples.models.llama.attention import KVCache, SDPA
 
+from executorch.extension.llm.modules.attention import SDPA as TTSDPA
+
 
 class SDPACustom(torch.nn.Module):
     def __init__(
@@ -60,11 +62,19 @@ def forward(
 def _replace_sdpa_with_custom_op(module: torch.nn.Module):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
+            breakpoint()
             setattr(
                 module,
                 name,
                 SDPACustom(child.dim),
             )
+        elif isinstance(child, TTSDPA):
+            breakpoint()
+            setattr(
+                module,
+                name,
+                SDPACustom(child.num_heads * child.head_dim),
+            )
         else:
             _replace_sdpa_with_custom_op(child)

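The TTSDPA branch above follows the same recursive module-swap pattern as the existing SDPA branch; it passes child.num_heads * child.head_dim because SDPACustom takes a single dim argument. A generic sketch of the pattern, with illustrative function and argument names that are not ExecuTorch API:

import torch

def replace_modules(module: torch.nn.Module, target_cls, make_replacement):
    # Recursively swap every child of type target_cls, mirroring
    # _replace_sdpa_with_custom_op above.
    for name, child in module.named_children():
        if isinstance(child, target_cls):
            setattr(module, name, make_replacement(child))
        else:
            replace_modules(child, target_cls, make_replacement)
    return module
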
examples/models/llama3_2_lora/__init__.py (new file, +11)

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama3_2_Lora

__all__ = [
    "Llama3_2_Lora",
]
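
Re-exporting Llama3_2_Lora here lets the exporter resolve it from the module_name / model_class_name strings set in _load_llama_model. A rough sketch of that lookup, assuming the package path follows the llama3_2_vision layout (the path below is an inference, not confirmed by this diff):

import importlib

module = importlib.import_module("executorch.examples.models.llama3_2_lora")  # assumed path
model_cls = getattr(module, "Llama3_2_Lora")
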
examples/models/llama3_2_lora/model.py (new file, +155)

@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Any, Dict

import torch

from executorch.examples.models.checkpoint import get_checkpoint_dtype
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope, RotaryEmbedding
from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import (
    replace_mha_with_inference_mha,
    replace_rope_with_inference_rope,
)

from torchtune.models import convert_weights

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2._component_builders import lora_llama3_2


class Llama3_2_Lora(EagerModelBase):
    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        ) # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
        # self.encoder_max_seq_len = kwargs.get(
        #     "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        # ) # Same as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", True)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)
        self.dtype = kwargs.get("dtype", torch.float16)
        self.use_checkpoint = False
        self.max_context_len = kwargs.get("max_context_len", 8192)

        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint")

        if os.path.isfile(checkpoint_path):
            self.use_checkpoint = True

        params_path = kwargs.get("params")
        adapter_path = kwargs.get("adapter")

        # self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
        # Load checkpoint and params.
        device = "cpu"
        if self.use_checkpoint:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
            checkpoint = convert_weights.meta_to_tune(checkpoint)
            self.dtype = get_checkpoint_dtype(checkpoint)

        adapter = torch.load(
            adapter_path, map_location="cpu", mmap=True, weights_only=False
        )

        checkpoint.update(adapter)

        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Load model.
        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
        # i.e. the model isn't fully initialized or something.
        self.model_ = lora_llama3_2(
            lora_attn_modules=[
                "q_proj",
            ],
            apply_lora_to_mlp=False,
            apply_lora_to_output=False,
            # llama3_2 args
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len, # 131072
            # intermediate_dim=params["intermediate_dim"], # 8192, calc is 4096
            # LoRA args. TODO take in the adapter config.
            lora_rank=8,
            lora_alpha=16,
        )
        self.model_.requires_grad_(False)
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        setattr(self.model_, "enable_dynamic_shape", self.enable_dynamic_shape)
        # Source transformation for MultiHeadAttention
        self.model_ = replace_mha_with_inference_mha(self.model_)

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_context_len=self.max_context_len,
            use_kv_cache=self.use_kv_cache,
            generate_full_logits=self.generate_full_logits,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        # Source transformation for RoPE
        # self.model_ = replace_rope_with_inference_rope(self.model_, model_args)

        setattr(self.model_, "checkpoint_dtype", self.dtype)
        if self.use_checkpoint:
            # Load checkpoint.
            missing, unexpected = self.model_.load_state_dict(
                checkpoint,
                strict=False,
                assign=True,
            )
            if kwargs.get("verbose", False):
                print("============= missing keys ================")
                print(missing)
                print("============= /missing ================")
                print("============= unexpected keys ================")
                print(unexpected)
                print("============= /unexpected ================")

        self.model_.to(self.dtype)
        # breakpoint() # 2, OK.

    def get_eager_model(self) -> torch.nn.Module:
        return self.model_

    def get_example_inputs(self):
        return (torch.tensor([[2, 3, 4]], dtype=torch.int64),)
        # return (
        #     torch.tensor([[2, 3, 4]], dtype=torch.long),
        #     {"input_pos": torch.tensor([0], dtype=torch.long)},
        # )
        # return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

    # eg=torch.tensor([[2, 3, 4]], dtype=torch.int64)
    # ip=torch.tensor([[0, 1, 2]], dtype=torch.long)
    def get_example_kwarg_inputs(self):
        return {"input_pos": torch.tensor([[0, 1, 2]], dtype=torch.long)}

    def get_dynamic_shapes(self):
        dim = torch.export.Dim("token_dim", min=1, max=self.max_seq_len - 1)
        return ({1: dim}, {1: dim})

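A rough usage sketch of the new wrapper outside the export_llama flow; the import path and file paths below are placeholders, and the forward call matches the debug checks in export_llama_lib.py:

import torch

# Package path assumed from the __init__.py above; checkpoint/adapter/params
# paths are placeholders for a Meta-format Llama 3.2 checkpoint, a torchtune
# LoRA adapter state dict, and the matching params.json.
from executorch.examples.models.llama3_2_lora import Llama3_2_Lora

m = Llama3_2_Lora(
    checkpoint="/path/to/consolidated.00.pth",
    adapter="/path/to/adapter_weights.pt",
    params="/path/to/params.json",
)
model = m.get_eager_model()
(tokens,) = m.get_example_inputs()
kwargs = m.get_example_kwarg_inputs()
with torch.no_grad():
    logits = model(tokens, **kwargs)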