
Commit 808220e

torchtune lora experiment
1 parent 644b7dd commit 808220e

File tree: 12 files changed (+645, -21 lines)


examples/models/llama/attention.py (+3)

@@ -185,6 +185,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
+        # self.wq = nn.Linear(
+        #     self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # )
         self.wq = nn.Linear(
             self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
         )

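The commented-out duplicate of self.wq above marks where the experiment would swap the plain nn.Linear projection for a LoRA-adapted one. A minimal sketch of that idea using torchtune's LoRALinear follows; the dimensions, rank, and alpha are illustrative assumptions, not values from this commit:

import torch
from torchtune.modules.peft import LoRALinear

# Illustrative dimensions only (not from this commit).
dim, n_heads, head_dim = 2048, 32, 64

# Hypothetical LoRA-adapted replacement for wq; rank/alpha would normally come
# from the finetune config. Output is roughly base(x) + (alpha / rank) * B(A(x)).
wq_lora = LoRALinear(
    in_dim=dim,
    out_dim=n_heads * head_dim,
    rank=8,
    alpha=16.0,
)

x = torch.randn(1, 3, dim)
print(wq_lora(x).shape)  # torch.Size([1, 3, 2048])
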
examples/models/llama/export_llama_lib.py (+50, -4)
@@ -101,7 +101,7 @@
     "phi_4_mini",
     "smollm2",
 ]
-TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_lora"]
 HUGGING_FACE_REPO_IDS = {
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Adapter path",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
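
The new --adapter flag only threads a file path through to the model constructor. A LoRA adapter checkpoint is typically a small state dict holding just the low-rank weights; a quick way to inspect one (the file name and key layout here are assumptions, not from this commit):

import torch

adapter_sd = torch.load("adapter_model.pt", map_location="cpu")  # hypothetical path
for key, value in adapter_sd.items():
    # Expect only LoRA tensors (e.g. keys ending in lora_a.weight / lora_b.weight);
    # the base weights still come from --checkpoint.
    print(key, tuple(value.shape))
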
@@ -585,17 +591,20 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
     # Convert dtype override string arg to actual type.
     dtype_override = DType[args.dtype_override]
 
+    # breakpoint() # 1, OK.
     edge_manager = _load_llama_model(
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
-
     # At this point, the model is loaded in the default fp32.
 
     # Checkpoint dtype should be lower or equal precision to the dtype override.
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+
+    em1 = edge_manager.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(eager, em1)
+    # breakpoint() # 4, OK.
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
     if not (
         checkpoint_dtype == dtype_override.to_torch_dtype()
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    # edge_manager.model = edge_manager.model.to(dtype=torch.float32)
+    em2 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em2, eager)
+    # breakpoint() # 5, not OK, gets converted to bf16. OK if dtype is consistent.
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
             args=args,
         )
     )
-
+    # torch.allclose here as well.
+    em3 = edge_manager.model.forward(eg, input_pos=ip)
+    torch.allclose(em3, eager)
     return edge_manager
 
 
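The forward/allclose calls sprinkled through this function follow one pattern: run the eager model once on a fixed prompt, save the logits, then re-run the same prompt after each stage of the export pipeline and compare against that saved reference. A minimal sketch of the pattern; the reference path and tolerance are assumptions for illustration:

import torch

def check_against_eager(model, reference_path="eager_res.pt", atol=1e-5):
    # Fixed token ids and positions; must match whatever produced the reference.
    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)

    out = model.forward(eg, input_pos=ip)
    eager = torch.load(reference_path)
    return torch.allclose(out, eager, atol=atol)

# e.g. assert check_against_eager(edge_manager.model) after each transform.
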
@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama( # noqa: C901
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
+    breakpoint()
+    # ^to_edge_res.pt
+    # allclose 1e-1 compared to pre-auto.
 
     # to_backend
     partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
 
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
+    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
+    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
+    b_e = builder_exported.model.forward(eg, input_pos=ip)
+    eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
+    torch.allclose(b_e, eager)
+    # breakpoint()
+
     builder_exported.run_canonical_optimizations()
+    b_e2 = builder_exported.model.forward(eg, input_pos=ip)
+    torch.allclose(b_e2, eager)
     modelname = builder_exported.modelname
 
     if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             args,
         )
     else:
+        # breakpoint()
+        b_e3 = builder_exported.model.forward(eg, input_pos=ip)
+        torch.allclose(b_e3, eager)
         builder = _to_edge_and_lower_llama(
             builder_exported,
             modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             quant_dtype,
             args,
         )
+    breakpoint()
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
         if modelname == "llama3_2_vision":
             module_name = "llama3_2_vision"
             model_class_name = "Llama3_2Decoder"
+        elif modelname == "llama3_2_lora":
+            module_name = "llama3_2_lora"
+            model_class_name = "Llama3_2_Lora"
         else:
             raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
         )
     )
 
+    # breakpoint() # 3. OK.
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
             model.max_seq_len,
             # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
             #  `Union[Tensor, Module]`.
-            model.max_context_len,
+            max_context_len,
             # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             #  Module]`.
             model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    # transforms.append(
+    #     replace_rope_with_inference_rope()
+    # )
     return transforms
 
 

examples/models/llama/source_transformation/sdpa.py (+71, -2)
@@ -9,12 +9,16 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
 from executorch.examples.models.llama.attention import KVCache, SDPA
 
+# from executorch.extension.llm.modules.attention import SDPA as TTSDPA
+
+from torchtune.modules.attention_utils import _MaskType
+
 
 class SDPACustom(torch.nn.Module):
     def __init__(
@@ -49,7 +53,7 @@ def forward(
             q,
             k,
             v,
-            input_pos[0].item(),
+            input_pos[2].item(),
             None, # Attention mask
             0, # dropout probability. Ignored by the code
             True, # is_causal
@@ -60,11 +64,19 @@ def forward(
 def _replace_sdpa_with_custom_op(module: torch.nn.Module):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
+            breakpoint()
             setattr(
                 module,
                 name,
                 SDPACustom(child.dim),
             )
+        # elif isinstance(child, TTSDPA):
+        #     # breakpoint()
+        #     setattr(
+        #         module,
+        #         name,
+        #         SDPAConverter(child.num_heads * child.head_dim),
+        #     )
         else:
             _replace_sdpa_with_custom_op(child)
 
@@ -76,6 +88,63 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
 
 
+# Convert from torchtune SDPA to SDPACustom.
+class SDPAConverter(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.SDPA = SDPACustom(dim)
+
+    def forward(
+        self,
+        q: torch.Tensor,  # [b, s, n_h, h_d]
+        k: torch.Tensor,  # [b, s, n_kv, h_d]
+        v: torch.Tensor,  # [b, s, n_kv, h_d]
+        bsz: int,
+        seq_len: int,
+        mask: Optional[_MaskType] = None,
+    ):
+        # input_pos = 0
+        # Mask isn't used in SDPA?
+
+        # Derive start_pos from the mask: take its first row and find the index
+        # of the first masked-out entry; the position just before it is the
+        # current one.
+        # assert mask is not None
+        if mask is not None:
+            attention_mask = mask.reshape(-1, seq_len)
+            first_row = attention_mask[0, :]
+            start_pos = torch.argmin(first_row).item() - 1
+        else:
+            start_pos = 0
+
+        q = q.transpose(1, 2)  # [b, n_h, s, h_d] after transpose
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
+
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            mask, # Attention mask
+            0, # dropout probability. Ignored by the code
+            True, # is_causal
+        )
+        return output.view(bsz, seq_len, self.dim).to(dtype=input_dtype)
+        # return self.SDPA(start_pos, q, k, v, bsz, seq_len, mask)
+
+
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
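
SDPAConverter's only non-obvious step is recovering a start position from the torchtune-style boolean mask: in the first row, True entries cover the already-cached positions plus the current token, so the index of the first False minus one is the current position. A standalone sanity check of that derivation (not part of the commit; the mask row is made up):

import torch

# Mask row for a token at position 2 with a cache of length 8: positions 0..2
# are visible (True), the rest are masked out (False).
row = torch.tensor([True, True, True, False, False, False, False, False])

# argmin over the int view picks the first False, so subtracting 1 recovers the
# index of the current token.
start_pos = torch.argmin(row.int()).item() - 1
print(start_pos)  # 2
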
New file (+11)

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import Llama3_2_Lora
+
+__all__ = [
+    "Llama3_2_Lora",
+]
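
This __init__ re-exports Llama3_2_Lora from a sibling model.py that is part of the same commit but not shown in this capture. One plausible way for such a wrapper to obtain a LoRA-adapted decoder is torchtune's LoRA model builders; the builder name, arguments, and values below are assumptions about the installed torchtune version, not code from this commit:

import torch
from torchtune.models.llama3_2 import lora_llama3_2_1b  # assumed builder

# Llama 3.2 1B with LoRA applied to the q/v attention projections.
model = lora_llama3_2_1b(
    lora_attn_modules=["q_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
)

# Base weights would come from --checkpoint and LoRA weights from --adapter:
# model.load_state_dict(base_sd, strict=False)
# model.load_state_dict(adapter_sd, strict=False)

tokens = torch.tensor([[2, 3, 4]])
with torch.no_grad():
    logits = model(tokens)
print(logits.shape)  # [1, 3, vocab_size]
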
