[do not land] lora experiment #9863

Draft · wants to merge 1 commit into main
3 changes: 3 additions & 0 deletions examples/models/llama/attention.py
@@ -185,6 +185,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

# self.wq = nn.Linear(
# self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
# )
self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
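# Illustrative sketch (an assumption about where this experiment is headed, not
# code from this change): a LoRA-adapted projection keeps the frozen base weight
# and adds a low-rank update B(A(x)) scaled by alpha / rank. Names and defaults
# below are hypothetical; torch.nn is assumed to be imported as nn, as in this file.
class LoRALinearSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # frozen pretrained weight
        self.lora_a = nn.Linear(in_dim, rank, bias=False)   # low-rank down-projection A
        self.lora_b = nn.Linear(rank, out_dim, bias=False)  # low-rank up-projection B
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))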
54 changes: 50 additions & 4 deletions examples/models/llama/export_llama_lib.py
@@ -101,7 +101,7 @@
"phi_4_mini",
"smollm2",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_lora"]
HUGGING_FACE_REPO_IDS = {
"qwen2_5": "Qwen/Qwen2.5-1.5B",
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
@@ -209,6 +209,12 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--adapter",
default=None,
help="Adapter path",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
@@ -585,17 +591,20 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
checkpoint_dir = (
canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
)
adapter_path = canonical_path(args.adapter) if args.adapter else None
params_path = canonical_path(args.params) if args.params else None
output_dir_path = canonical_path(args.output_dir, dir=True)
weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

# Convert dtype override string arg to actual type.
dtype_override = DType[args.dtype_override]

# breakpoint() # 1, OK.
edge_manager = _load_llama_model(
args.model,
checkpoint=checkpoint_path,
checkpoint_dir=checkpoint_dir,
adapter=adapter_path,
params_path=params_path,
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -616,10 +625,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
dtype_override=dtype_override,
args=args,
)

# At this point, the model is loaded in the default fp32.

# Checkpoint dtype should be lower or equal precision to the dtype override.
eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
ip = torch.tensor([[0, 1, 2]], dtype=torch.long)

em1 = edge_manager.model.forward(eg, input_pos=ip)
eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
torch.allclose(eager, em1)
# breakpoint() # 4, OK.
checkpoint_dtype = edge_manager.model.checkpoint_dtype
if not (
checkpoint_dtype == dtype_override.to_torch_dtype()
@@ -637,6 +652,10 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
# edge_manager.model = edge_manager.model.to(dtype=torch.float32)
em2 = edge_manager.model.forward(eg, input_pos=ip)
torch.allclose(em2, eager)
# breakpoint() # 5, not OK, gets converted to bf16. OK if dtype is consistent.
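# Illustrative note: bf16 keeps only ~8 mantissa bits (about 2-3 significant
# decimal digits), so an allclose against the fp32 eager reference at default
# tolerances is expected to fail after the cast; a looser check such as
# torch.allclose(em2.to(torch.float32), eager, atol=1e-2, rtol=1e-2) is more
# informative when dtype_override is bf16.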

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
@@ -649,7 +668,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
args=args,
)
)

# torch.allclose here as well.
em3 = edge_manager.model.forward(eg, input_pos=ip)
torch.allclose(em3, eager)
return edge_manager
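
# Illustrative helper (an assumption, not part of this change): the repeated
# eager-reference checks in this file can be folded into one function. The
# reference path is the author's local file and is treated as a parameter here.
def _check_against_eager_reference(model, reference_path: str, atol: float = 1e-5) -> bool:
    eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
    ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
    out = model.forward(eg, input_pos=ip)
    eager = torch.load(reference_path)
    return torch.allclose(out, eager, atol=atol)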


@@ -777,6 +798,9 @@ def _to_edge_and_lower_llama( # noqa: C901
builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()
breakpoint()
# ^to_edge_res.pt
# allclose 1e-1 compared to pre-auto.

# to_backend
partitioners = []
@@ -911,7 +935,16 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901

# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()
eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
ip = torch.tensor([[0, 1, 2]], dtype=torch.long)
b_e = builder_exported.model.forward(eg, input_pos=ip)
eager = torch.load("/data/users/lfq/executorch/eager_res.pt")
torch.allclose(b_e, eager)
# breakpoint()

builder_exported.run_canonical_optimizations()
b_e2 = builder_exported.model.forward(eg, input_pos=ip)
torch.allclose(b_e2, eager)
modelname = builder_exported.modelname

if args.export_only:
@@ -932,6 +965,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
args,
)
else:
# breakpoint()
b_e3 = builder_exported.model.forward(eg, input_pos=ip)
torch.allclose(b_e3, eager)
builder = _to_edge_and_lower_llama(
builder_exported,
modelname,
@@ -941,6 +977,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
quant_dtype,
args,
)
breakpoint()

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -1004,6 +1041,7 @@ def _load_llama_model(
*,
checkpoint: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
adapter: Optional[str] = None,
params_path: Optional[str] = None,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
@@ -1038,6 +1076,9 @@ def _load_llama_model(
if modelname == "llama3_2_vision":
module_name = "llama3_2_vision"
model_class_name = "Llama3_2Decoder"
if modelname == "llama3_2_lora":
module_name = "llama3_2_lora"
model_class_name = "Llama3_2_Lora"
else:
raise ValueError(f"{modelname} is not a valid Llama model.")
else:
@@ -1051,6 +1092,7 @@ def _load_llama_model(
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
adapter=adapter,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1066,6 +1108,7 @@ def _load_llama_model(
)
)

# breakpoint() # 3. OK.
return LLMEdgeManager(
model=model,
modelname=modelname,
@@ -1093,7 +1136,7 @@ def _load_llama_model(
model.max_seq_len,
# pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_context_len,
max_context_len,
# pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
@@ -1244,6 +1287,9 @@ def _get_source_transforms( # noqa
if args.vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

# transforms.append(
# replace_rope_with_inference_rope()
# )
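# Note: entries appended to this list are source transforms applied to the
# model in order; each ultimately resolves to a callable taking a
# torch.nn.Module and returning the rewritten module (see
# replace_sdpa_with_custom_op in source_transformation/sdpa.py).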
return transforms


73 changes: 71 additions & 2 deletions examples/models/llama/source_transformation/sdpa.py
@@ -9,12 +9,16 @@
# Example script for exporting Llama2 to flatbuffer

import math
from typing import Tuple
from typing import Optional, Tuple

import torch

from executorch.examples.models.llama.attention import KVCache, SDPA

# from executorch.extension.llm.modules.attention import SDPA as TTSDPA

from torchtune.modules.attention_utils import _MaskType


class SDPACustom(torch.nn.Module):
def __init__(
@@ -49,7 +53,7 @@ def forward(
q,
k,
v,
input_pos[0].item(),
input_pos.item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
@@ -60,11 +64,19 @@ def forward(
def _replace_sdpa_with_custom_op(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
breakpoint()
setattr(
module,
name,
SDPACustom(child.dim),
)
# elif isinstance(child, TTSDPA):
# # breakpoint()
# setattr(
# module,
# name,
# SDPAConverter(child.num_heads * child.head_dim),
# )
else:
_replace_sdpa_with_custom_op(child)

@@ -76,6 +88,63 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
return module


# Convert from torchtune SDPA to SDPACustom.
class SDPAConverter(torch.nn.Module):
def __init__(
self,
dim: int,
):
super().__init__()
self.dim = dim
self.SDPA = SDPACustom(dim)

def forward(
self,
q: torch.Tensor, # [b, s, n_h, h_d]
k: torch.Tensor, # [b, s, n_kv, h_d]
v: torch.Tensor, # [b, s, n_kv, h_d]
bsz: int,
seq_len: int,
mask: Optional[_MaskType] = None,
):
# input_pos = 0
# Mask isn't used in SDPA?
# Make sure mask isn't None: take the first row of the mask, count the
# leading 1s/Trues, i.e. find the index of the first zero/False via argmin,
# and use that index minus one as start_pos.
# assert mask is not None
if mask is not None:
# Reshape by the mask's own last dimension; max_seq_len is not defined in this scope.
attention_mask = mask.reshape(-1, mask.shape[-1])
first_row = attention_mask[0, :]
start_pos = torch.argmin(first_row).item() - 1
else:
start_pos = 0

##
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
input_dtype = q.dtype
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
v = v.to(dtype=torch.float)

output = torch.ops.llama.custom_sdpa(
q,
k,
v,
start_pos,
mask, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seq_len, self.dim).to(dtype=input_dtype)
# return self.SDPA(start_pos, q, k, v, bsz, seq_len, mask)
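
# Worked example of the mask -> start_pos arithmetic above (illustrative; uses
# an integer mask to avoid any dtype restrictions on argmin): for a causal row
# [1, 1, 1, 0, 0], argmin returns 3 (the first zero), so start_pos = 3 - 1 = 2.
def _example_start_pos_from_mask() -> int:
    mask_row = torch.tensor([1, 1, 1, 0, 0])
    return int(torch.argmin(mask_row).item()) - 1  # -> 2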


class SDPASimple(torch.nn.Module):
def __init__(
self,
11 changes: 11 additions & 0 deletions examples/models/llama3_2_lora/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama3_2_Lora

__all__ = [
"Llama3_2_Lora",
]