From 10c9c85b4a5625b1eb2c587fc3d9f36dbce302aa Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 4 Apr 2025 11:50:22 -0700
Subject: [PATCH] Export llama with LoRA

Add a LoRALinear module for the wq projection and an --adapter flag so a
torchtune LoRA adapter checkpoint can be converted to Meta format and
merged into the base checkpoint before export. --adapter_config is
accepted but not consumed yet; rank/alpha are hardcoded for now.
---
 examples/models/llama/attention.py        | 48 ++++++++++++++-
 examples/models/llama/export_llama_lib.py | 16 ++++++
 examples/models/llama/model.py            | 13 ++++
 3 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 54f738ba737..1fd17852d73 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -160,6 +160,44 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
+class LoRALinear(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+
+        # Set up the base (non-LoRA) weight and bias.
+        linear_q = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear_q.weight
+        bias = linear_q.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
@@ -185,9 +223,15 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        # TODO: take rank/alpha/dropout from --adapter_config instead of hardcoding.
+        self.wq = LoRALinear(
+            in_dim=self.dim,
+            out_dim=self.n_heads * self.head_dim,
+            rank=8,
+            alpha=16.0,
+            dropout=0.0,
+            use_bias=self.attention_qkv_bias,
         )
         self.wk = nn.Linear(
             self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
         )
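Note on the LoRALinear module above: it computes out = W.x + (alpha / rank) * lora_b(lora_a(dropout(x))), where the base weight W is loaded from the original checkpoint and lora_a / lora_b come from the adapter. The following is a minimal standalone sanity check, not part of the patch; it assumes the patch is applied and executorch is importable, and the dimensions, rank, and alpha are illustrative only.

    import torch

    from executorch.examples.models.llama.attention import LoRALinear

    # A 2048 -> 2048 projection with a rank-8 adapter (illustrative sizes).
    layer = LoRALinear(in_dim=2048, out_dim=2048, rank=8, alpha=16.0)
    x = torch.randn(1, 16, 2048)  # (batch, seq_len, in_dim)
    y = layer(x)  # base linear output + (alpha / rank) * lora_b(lora_a(x))
    assert y.shape == (1, 16, 2048)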
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index f1c5c3a73f1..739b6086286 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -455,6 +455,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Whether the checkpoin is pre-quantized with QAT or not.",
     )
 
+    parser.add_argument(
+        "--adapter",
+        default=None,
+        help="Path to the LoRA adapter checkpoint.",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        default=None,
+        help="Path to the LoRA adapter config file.",
+    )
+
     parser.add_argument(
         "-lora",
         "--use_lora",
@@ -591,6 +603,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
+    adapter_path = canonical_path(args.adapter) if args.adapter else None
     params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
@@ -602,6 +615,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         args.model,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
+        adapter=adapter_path,
         params_path=params_path,
        use_kv_cache=args.use_kv_cache,
        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
@@ -1020,6 +1034,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
+    adapter: Optional[str] = None,
     params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -1067,6 +1082,7 @@ def _load_llama_model(
             model_class_name,
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
+            adapter=adapter,
             params=params_path,
             use_kv_cache=use_kv_cache,
             use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 19829576482..2cc63cc5c2c 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -19,6 +19,8 @@
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint
 
@@ -45,6 +47,9 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
+        # Adapter file.
+        adapter_path = kwargs.get("adapter", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -96,6 +101,14 @@ def __init__(self, **kwargs):
         elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
+        # Load the LoRA adapter and merge its weights into the checkpoint.
+        if adapter_path:
+            print("Loading adapter from: ", adapter_path)
+            adapter = torch.load(adapter_path, map_location=device, mmap=True)
+            # Convert the adapter keys from torchtune format to Meta format.
+            adapter = convert_weights.tune_to_meta(adapter)
+            checkpoint.update(adapter)
+
         # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
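Note on the adapter handling in model.py above: the adapter tensors are merged into the checkpoint dict, so both the base weights and the lora_a / lora_b weights are present when the Transformer's state dict is loaded, and the low-rank branch becomes part of the exported graph. A minimal sketch of that merge, assuming a torchtune-style LoRA adapter; the file names are placeholders.

    import torch

    from torchtune.models import convert_weights

    # Base checkpoint and adapter paths are placeholders.
    checkpoint = torch.load("consolidated.00.pth", map_location="cpu", mmap=True)
    adapter = torch.load("adapter_model.pt", map_location="cpu", mmap=True)

    # Rename adapter keys from the torchtune convention to the Meta convention,
    # mirroring what model.py does, then merge into the base state dict.
    adapter = convert_weights.tune_to_meta(adapter)
    checkpoint.update(adapter)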