export llama with lora #9916

Draft · wants to merge 1 commit into base: main
61 changes: 59 additions & 2 deletions examples/models/llama/attention.py
@@ -160,6 +160,48 @@ def forward(
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


class LoRALinear(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
rank: int,
alpha: float,
dropout: float = 0.0,
use_bias: bool = False,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.rank = rank
self.alpha = alpha
self.use_bias = use_bias

        # Materialize a plain nn.Linear to initialize the base weight and
        # optional bias, then register them as this module's parameters.
linear_q = nn.Linear(in_dim, out_dim, bias=use_bias)
weight = linear_q.weight
bias = linear_q.bias if self.use_bias else None
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out


@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
@@ -185,9 +227,19 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
self.wq = LoRALinear(
in_dim=self.dim,
out_dim=self.n_heads * self.head_dim,
rank=8,
alpha=16.0,
dropout=0.0,
use_bias=self.attention_qkv_bias,
)

self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
@@ -238,6 +290,10 @@ def forward(

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)

# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -268,6 +324,7 @@ def forward(

mask = self.mask[:seqlen, :seqlen]

        # Note: k and v unexpectedly end up as float tensors at this point.
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
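Note: a minimal sketch of how the LoRALinear layer added above behaves, assuming it is importable from executorch.examples.models.llama.attention once this lands. It keeps the base projection and adds a low-rank update scaled by alpha / rank, so the output shape matches the nn.Linear it replaces. Dimensions below are illustrative only. (As written, lora_a and lora_b use the default nn.Linear initialization; LoRA conventionally zero-initializes lora_b so the wrapped layer starts out identical to the base layer.)

import torch

from executorch.examples.models.llama.attention import LoRALinear

# Wrap a 2048 -> 2048 projection with a rank-8 adapter (alpha = 16), mirroring
# the wq replacement in AttentionMHA above.
lora = LoRALinear(in_dim=2048, out_dim=2048, rank=8, alpha=16.0, dropout=0.0, use_bias=False)

x = torch.randn(1, 16, 2048)
y = lora(x)  # base linear output + (alpha / rank) * lora_b(lora_a(x))
print(y.shape)  # torch.Size([1, 16, 2048])
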
37 changes: 34 additions & 3 deletions examples/models/llama/export_llama_lib.py
@@ -26,8 +26,9 @@

from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import print_delegation_info

from executorch.devtools.etrecord import generate_etrecord

from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.hf_download import (
download_and_convert_hf_checkpoint,
)
@@ -455,6 +456,18 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Whether the checkpoin is pre-quantized with QAT or not.",
)

parser.add_argument(
"--adapter",
default=None,
help="Adapter path",
)

parser.add_argument(
"--adapter_config",
default=None,
help="Adapter config path",
)

parser.add_argument(
"-lora",
"--use_lora",
@@ -591,6 +604,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
checkpoint_dir = (
canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
)
adapter_path = canonical_path(args.adapter) if args.adapter else None
params_path = canonical_path(args.params) if args.params else None
output_dir_path = canonical_path(args.output_dir, dir=True)
weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
@@ -602,6 +616,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
args.model,
checkpoint=checkpoint_path,
checkpoint_dir=checkpoint_dir,
adapter=adapter_path,
params_path=params_path,
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -641,8 +656,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
@@ -656,10 +671,12 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
)
)

return edge_manager


def get_quantizer_and_quant_params(args):
pt2e_quant_params = get_pt2e_quantization_params(
args.pt2e_quantize, args.quantization_mode
)
@@ -948,6 +965,11 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
args,
)
else:
        # Debug inputs for manual inspection, kept commented out:
        # eg = torch.tensor([[2, 3, 4]], dtype=torch.int64)
        # fw = ForwardOptions(input_pos=torch.tensor([0], dtype=torch.long))

builder = _to_edge_and_lower_llama(
builder_exported,
modelname,
@@ -958,6 +980,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
args,
)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -1020,6 +1043,7 @@ def _load_llama_model(
*,
checkpoint: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
adapter: Optional[str] = None,
params_path: Optional[str] = None,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
@@ -1067,6 +1091,7 @@ def _load_llama_model(
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
adapter=adapter,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
@@ -1081,6 +1106,11 @@ def _load_llama_model(
args=args,
)
)
    # Optional eager smoke test with the merged adapter weights, kept commented out:
    # eg = torch.tensor([[13347]], dtype=torch.long)
    # fw = ForwardOptions(input_pos=torch.tensor([0], dtype=torch.long))
    # model.forward(eg, fw)

return LLMEdgeManager(
model=model,
@@ -1206,6 +1236,7 @@ def _get_source_transforms( # noqa
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
transforms.append(replace_sdpa_with_custom_op)

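Note: the new --adapter and --adapter_config flags are plain string options; args.adapter is canonicalized and passed through _load_llama_model into the model constructor as the adapter kwarg. A self-contained sketch of how the flags parse, with placeholder paths:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adapter", default=None, help="Path to the LoRA adapter checkpoint.")
parser.add_argument("--adapter_config", default=None, help="Path to the LoRA adapter config file.")

# Example invocation; the file names are hypothetical.
args = parser.parse_args(
    ["--adapter", "adapter/adapter_model.pt", "--adapter_config", "adapter/adapter_config.json"]
)
print(args.adapter, args.adapter_config)
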
17 changes: 17 additions & 0 deletions examples/models/llama/model.py
@@ -19,6 +19,8 @@

from executorch.examples.models.llama.model_args import ModelArgs

from torchtune.models import convert_weights

try:
from .fairseq2 import convert_to_llama_checkpoint

@@ -45,6 +47,9 @@ def __init__(self, **kwargs):
# Params file.
params_path = kwargs.get("params", None)

# Adapter file.
adapter_path = kwargs.get("adapter", None)

self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -96,6 +101,15 @@ def __init__(self, **kwargs):
elif checkpoint_path:
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # Load adapter weights and merge them into the base checkpoint.
        if adapter_path:
            print("Loading adapter from: ", adapter_path)
            adapter = torch.load(adapter_path, map_location=device, mmap=True)
            # Convert the adapter state dict from torchtune to Meta naming so its
            # keys line up with the base checkpoint.
            adapter = convert_weights.tune_to_meta(adapter)
            checkpoint.update(adapter)

# If given checkpoint is fairseq, convert to llama checkpoint.
fairseq2_checkpoint = kwargs.get("fairseq2", False)
if fairseq2_checkpoint:
@@ -174,8 +188,10 @@ def __init__(self, **kwargs):
with torch.device("meta"):
# Model itself is loaded in default dtype, fp32.
self.model_ = Transformer(model_args)

# Get checkpoint dtype.
if checkpoint:
self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
else:
self.model_.checkpoint_dtype = torch.float32
@@ -252,6 +268,7 @@ def __init__(self, **kwargs):
# by default initialized to fp32. This is fine because every other supported type
# losslessly converts to fp32, so we don't lose precision here.
if checkpoint:
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
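Note: a standalone sketch of the adapter merge performed above, assuming the adapter is a torchtune LoRA state dict saved with torch.save; file names are placeholders. The keys are renamed to Meta/llama naming with convert_weights.tune_to_meta and overlaid on the base checkpoint, which is later loaded with strict=False.

import torch
from torchtune.models import convert_weights

checkpoint = torch.load("consolidated.00.pth", map_location="cpu", mmap=True)
adapter = torch.load("adapter_model.pt", map_location="cpu", mmap=True)

# Convert torchtune-style keys to Meta/llama naming so they line up with the
# base checkpoint, then merge; the LoRA A/B tensors simply ride along in the dict.
adapter = convert_weights.tune_to_meta(adapter)
checkpoint.update(adapter)
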
@@ -283,6 +283,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
def _replace_kv_cache_with_custom_kv_cache(module):
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
max_batch_size, n_heads, max_context_length, head_dim = cache_shape
1 change: 1 addition & 0 deletions extension/llm/export/builder.py
@@ -165,6 +165,7 @@ def source_transform(
list of source transforms.
"""
for transform in transforms:
self.model = transform(self.model)
self.applied_source_transforms.extend(transforms)

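Note: each entry in transforms is a callable that takes the eager model and returns a (possibly rewritten) model, applied in order by source_transform. As an illustration of that pattern, a hypothetical transform (not part of this PR) that folds the LoRA update into the base weight so the exported graph sees a single linear:

import torch
import torch.nn as nn

def merge_lora_weights(model: nn.Module) -> nn.Module:
    # Fold W' = W + (alpha / rank) * B @ A into the base weight of every module
    # that carries lora_a / lora_b adapters, matching LoRALinear.forward (with
    # dropout inactive), then zero the adapters so forward() reduces to W'.
    for module in model.modules():
        if hasattr(module, "lora_a") and hasattr(module, "lora_b"):
            with torch.no_grad():
                delta = module.lora_b.weight @ module.lora_a.weight
                module.weight += (module.alpha / module.rank) * delta
                module.lora_a.weight.zero_()
                module.lora_b.weight.zero_()
    return model

# Would be registered like the other transforms in _get_source_transforms:
# transforms.append(merge_lora_weights)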