From a6f7ace29e1f7cfc48fccce4cd8869ffbb9dbf84 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Tue, 6 May 2025 07:52:43 +0000 Subject: [PATCH 1/7] Add Gemma3 Signed-off-by: vbaddi Signed-off-by: Mohit Soni --- .../transformers/models/gemma3/__init__.py | 6 + .../models/gemma3/modeling_gemma3.py | 551 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 20 + examples/gemma3_text.py | 66 +++ 4 files changed, 643 insertions(+) create mode 100644 QEfficient/transformers/models/gemma3/__init__.py create mode 100644 QEfficient/transformers/models/gemma3/modeling_gemma3.py create mode 100644 examples/gemma3_text.py diff --git a/QEfficient/transformers/models/gemma3/__init__.py b/QEfficient/transformers/models/gemma3/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/gemma3/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py new file mode 100644 index 000000000..0d05acea8 --- /dev/null +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -0,0 +1,551 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache, HybridCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3Config, + Gemma3DecoderLayer, + Gemma3ForCausalLM, + Gemma3TextModel, + logger, + repeat_kv, + rotate_half, +) + +from QEfficient.customop.rms_norm import CustomRMSNorm +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class GemmaRMSNormFunc(torch.autograd.Function): + @staticmethod + def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): + hidden_states = hidden_states.to(torch.float32) + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + return weight * hidden_states + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value: + return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) + + +class QEffGemma3CustomRMSNormAIC(nn.Module): + """ + RMSNorm module that works by replacing the current module with compiler known custom-op. 
+ """ + + def forward(self, hidden_states): + return GemmaRMSNormFunc.apply( + hidden_states, + self.weight.float() + 1.0, + self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps, + ) + + +class QEffGemma3RotaryEmbedding(nn.Module): + """ + Copied from Gemma2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffGemma3Attention(Gemma3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffGemma3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + + config = copy.deepcopy(self.config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + + self.rotary_emb_local = QEffGemma3RotaryEmbedding( + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if self.is_sliding: + cos, sin = self.rotary_emb_local(value_states, seq_len=kv_seq_len) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + # import ipdb; ipdb.set_trace() + if self.config.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.config.attn_logit_softcapping + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask.bool(), torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class QEffGemma3DecoderLayer(Gemma3DecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = 
self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class QEffGemma3TextModel(Gemma3TextModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3TextDecoderLayer`] + + Args: + config: Gemma3TextConfig + """ + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + # return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens, sliding_window=self.config.sliding_window + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if use_cache: + next_cache = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffGemma3ForCausalLMModel(Gemma3ForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3ForCausalLM + + >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 333c734ba..0529c237c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -35,6 +35,13 @@ 
Gemma2Model, Gemma2RMSNorm, ) +from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3DecoderLayer, + Gemma3ForCausalLM, + Gemma3RMSNorm, + Gemma3TextModel, +) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( GPTBigCodeAttention, @@ -157,6 +164,13 @@ QEffGemma2ForCausalLM, QEffGemma2Model, ) +from QEfficient.transformers.models.gemma3.modeling_gemma3 import ( + QEffGemma3Attention, + QEffGemma3CustomRMSNormAIC, + QEffGemma3DecoderLayer, + QEffGemma3ForCausalLMModel, + QEffGemma3TextModel, +) from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( QEffGPT2Attention, QEffGPT2Block, @@ -284,11 +298,17 @@ class CustomOpsTransform(ModuleMappingTransform): MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, + Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, } class KVCacheTransform(ModuleMappingTransform): _module_mapping = { + # Gemma3 + Gemma3Attention: QEffGemma3Attention, + Gemma3DecoderLayer: QEffGemma3DecoderLayer, + Gemma3TextModel: QEffGemma3TextModel, + Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, # CodeGen CodeGenAttention: QEffCodeGenAttention, CodeGenBlock: QeffCodeGenBlock, diff --git a/examples/gemma3_text.py b/examples/gemma3_text.py new file mode 100644 index 000000000..07aafa88b --- /dev/null +++ b/examples/gemma3_text.py @@ -0,0 +1,66 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from transformers import Gemma3ForCausalLM +from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + + +def add_named_scopes(model): + for name, module in model.named_modules(): + if isinstance(module, Gemma3RMSNorm): + module._onnx_scope_name = f"/{name}" + + +torch.manual_seed(42) +model_id = "google/gemma-3-4b-it" +model = Gemma3ForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, use_cache=True, attn_implementation="eager" +) +model.eval() + +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) +api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, +) +pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) +qeff_model = QEFFAutoModelForCausalLM(model) +pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) +assert ( + pytorch_hf_tokens == pytorch_kv_tokens +).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + +# add_named_scopes(qeff_model.model) +onnx_model_path = qeff_model.export() +ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=False) +assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." 
+ +qpc_path = qeff_model.compile( + prefill_seq_len=Constants.PROMPT_LEN, + ctx_len=Constants.CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR, device_ids=[0]) From 73e31f2e180a10aad12595a0a57c70a9a00f076e Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 8 May 2025 06:10:24 +0000 Subject: [PATCH 2/7] nit: update example script with node precision file Signed-off-by: vbaddi Signed-off-by: Mohit Soni --- .../fp32_nodes_gemma3_text.yaml | 275 ++++++++++++++++++ .../{ => causal_lm_examples}/gemma3_text.py | 20 +- 2 files changed, 276 insertions(+), 19 deletions(-) create mode 100644 examples/causal_lm_examples/fp32_nodes_gemma3_text.yaml rename examples/{ => causal_lm_examples}/gemma3_text.py (66%) diff --git a/examples/causal_lm_examples/fp32_nodes_gemma3_text.yaml b/examples/causal_lm_examples/fp32_nodes_gemma3_text.yaml new file mode 100644 index 000000000..494486e68 --- /dev/null +++ b/examples/causal_lm_examples/fp32_nodes_gemma3_text.yaml @@ -0,0 +1,275 @@ +FP32NodeInstanceNames: + - /model/layers.0/Add_1_output_0 + - /model/layers.0/Add_output_0 + - /model/layers.1/Add_1_output_0 + - /model/layers.1/Add_output_0 + - /model/layers.10/Add_1_output_0 + - /model/layers.10/Add_output_0 + - /model/layers.11/Add_1_output_0 + - /model/layers.11/Add_output_0 + - /model/layers.12/Add_1_output_0 + - /model/layers.12/Add_output_0 + - /model/layers.13/Add_1_output_0 + - /model/layers.13/Add_output_0 + - /model/layers.14/Add_1_output_0 + - /model/layers.14/Add_output_0 + - /model/layers.15/Add_1_output_0 + - /model/layers.15/Add_output_0 + - /model/layers.16/Add_1_output_0 + - /model/layers.16/Add_output_0 + - /model/layers.17/Add_1_output_0 + - /model/layers.17/Add_output_0 + - /model/layers.18/Add_1_output_0 + - /model/layers.18/Add_output_0 + - /model/layers.19/Add_1_output_0 + - /model/layers.19/Add_output_0 + - /model/layers.2/Add_1_output_0 + - /model/layers.2/Add_output_0 + - /model/layers.20/Add_1_output_0 + - /model/layers.20/Add_output_0 + - /model/layers.21/Add_1_output_0 + - /model/layers.21/Add_output_0 + - /model/layers.22/Add_1_output_0 + - /model/layers.22/Add_output_0 + - /model/layers.23/Add_1_output_0 + - /model/layers.23/Add_output_0 + - /model/layers.24/Add_1_output_0 + - /model/layers.24/Add_output_0 + - /model/layers.25/Add_1_output_0 + - /model/layers.25/Add_output_0 + - /model/layers.26/Add_1_output_0 + - /model/layers.26/Add_output_0 + - /model/layers.27/Add_1_output_0 + - /model/layers.27/Add_output_0 + - /model/layers.28/Add_1_output_0 + - /model/layers.28/Add_output_0 + - /model/layers.29/Add_1_output_0 + - /model/layers.29/Add_output_0 + - /model/layers.3/Add_1_output_0 + - /model/layers.3/Add_output_0 + - /model/layers.30/Add_1_output_0 + - /model/layers.30/Add_output_0 + - /model/layers.31/Add_1_output_0 + - /model/layers.31/Add_output_0 + - /model/layers.32/Add_1_output_0 + - /model/layers.32/Add_output_0 + - /model/layers.33/Add_1_output_0 + - /model/layers.33/Add_output_0 + - /model/layers.4/Add_1_output_0 + - /model/layers.4/Add_output_0 + - /model/layers.5/Add_1_output_0 + - /model/layers.5/Add_output_0 + - /model/layers.6/Add_1_output_0 + - /model/layers.6/Add_output_0 + - /model/layers.7/Add_1_output_0 + - /model/layers.7/Add_output_0 + - /model/layers.8/Add_1_output_0 + - /model/layers.8/Add_output_0 + - /model/layers.9/Add_1_output_0 + - 
/model/layers.9/Add_output_0 + - /model/norm/Add_output_0 + - /model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/norm/CustomRMSNorm_output_0 \ No newline at end of file diff --git a/examples/gemma3_text.py b/examples/causal_lm_examples/gemma3_text.py similarity index 66% rename from examples/gemma3_text.py rename to examples/causal_lm_examples/gemma3_text.py index 07aafa88b..278a2a28a 100644 --- a/examples/gemma3_text.py +++ b/examples/causal_lm_examples/gemma3_text.py @@ -12,7 +12,6 @@ from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils._utils import load_hf_tokenizer from QEfficient.utils.constants import Constants -from QEfficient.utils.run_utils import ApiRunner def add_named_scopes(model): @@ -29,27 +28,9 @@ def add_named_scopes(model): model.eval() tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) -config = model.config -batch_size = 
len(Constants.INPUT_STR) -api_runner = ApiRunner( - batch_size, - tokenizer, - config, - Constants.INPUT_STR, - Constants.PROMPT_LEN, - Constants.CTX_LEN, -) -pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) qeff_model = QEFFAutoModelForCausalLM(model) -pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) -assert ( - pytorch_hf_tokens == pytorch_kv_tokens -).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" -# add_named_scopes(qeff_model.model) onnx_model_path = qeff_model.export() -ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=False) -assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." qpc_path = qeff_model.compile( prefill_seq_len=Constants.PROMPT_LEN, @@ -61,6 +42,7 @@ def add_named_scopes(model): mos=1, aic_enable_depth_first=True, num_speculative_tokens=None, + node_precision_info="fp32_nodes_gemma3_text.yaml", ) print(f"qpc path is {qpc_path}") exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR, device_ids=[0]) From d19ee2b37e53605b5e405a9eda58070b5f9ea593 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 8 May 2025 10:36:51 +0000 Subject: [PATCH 3/7] nit: add multi modal modeling changes Signed-off-by: vbaddi Signed-off-by: Mohit Soni --- .../models/gemma3/modeling_gemma3.py | 213 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 13 +- 2 files changed, 221 insertions(+), 5 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 0d05acea8..4bd18f311 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -20,6 +20,7 @@ Gemma3Config, Gemma3DecoderLayer, Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, Gemma3TextModel, logger, repeat_kv, @@ -29,6 +30,8 @@ from QEfficient.customop.rms_norm import CustomRMSNorm from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config class GemmaRMSNormFunc(torch.autograd.Function): @@ -549,3 +552,213 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class QEffGemma3EncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.vision_tower + + def forward(self, input_ids, pixel_values): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + image_features = self.model.get_image_features(pixel_values=pixel_values) + selected = input_ids == self.model.config.image_token_index + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + return image_input_embeds + + +class QEffGemma3DecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.language_model = self.model.language_model + self.config = self.model.config + + def forward(self, input_ids, vision_embeds, position_ids, past_key_values): + image_embeds = vision_embeds[:, : input_ids.shape[1], :] + 
inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + ) + return outputs.logits, vision_embeds, outputs.past_key_values + + +class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEffGemma3EncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffGemma3DecoderWrapper(self) + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: int, + kv_offload: bool = False, + **compiler_options, + ): + vision_seq_len = compiler_options.pop("vision_seq_len", None) + if vision_seq_len is None: + # TODO: Check properly for Gemma3, Not verified yet. + vision_seq_len = 2560 # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size) + + prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + if img_size is None and hasattr(self.config.vision_config, "image_size"): + img_size = getattr(self.config.vision_config, "image_size") + elif img_size is None: + img_size = 896 # FIXME based on gemma3 Image size + logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") + + vision = [ + { + "batch_size": batch_size, + "img_size": img_size, + "seq_len": vision_seq_len, + "ctx_len": ctx_len, + } + ] + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "img_size": img_size, + "chunk_length": prefill_seq_len, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "img_size": img_size, + "chunk_length": prefill_seq_len, + }, + ] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + vision_dynamic_axes = {} + lang_dynamic_axes = {} + lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} + lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} + lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"} + vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} + vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} + + pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + + dynamic_axes = {} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + 
lang_output_names.insert(1, "pixel_values_RetainedState") + return lang_output_names + return output_names + + def get_dummy_inputs(self, kv_offload: bool = False): + if vis_cfg := getattr(self.config, "vision_config", None): + img_size = getattr(vis_cfg, "image_size", 896) + else: + img_size = 896 + + # Define shapes + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + inputs_shapes["vision_embeds"] = ( + 1, # constants.INTERN_NUM_PATCHES, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, # constants.INTERN_FEATURE_SIZE, + self.language_model.config.hidden_size, # 5120 + ) + inputs_shapes["position_ids"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.INTERN_NUM_CHANNELS, + img_size, + img_size, + ) + + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.language_model.config, + batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo( + name="pixel_values", + datatype=torch.float32, + shape=("batch_size", 3, "img_size", "img_size"), + ), + ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 0529c237c..b088f28c3 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -39,6 +39,7 @@ Gemma3Attention, Gemma3DecoderLayer, Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, Gemma3RMSNorm, Gemma3TextModel, ) @@ -169,6 +170,7 @@ QEffGemma3CustomRMSNormAIC, QEffGemma3DecoderLayer, QEffGemma3ForCausalLMModel, + QEffGemma3ForConditionalGeneration, QEffGemma3TextModel, ) from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( @@ -304,11 +306,6 @@ class CustomOpsTransform(ModuleMappingTransform): class KVCacheTransform(ModuleMappingTransform): _module_mapping = { - # Gemma3 - Gemma3Attention: QEffGemma3Attention, - Gemma3DecoderLayer: QEffGemma3DecoderLayer, - Gemma3TextModel: QEffGemma3TextModel, - 
Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, # CodeGen CodeGenAttention: QEffCodeGenAttention, CodeGenBlock: QeffCodeGenBlock, @@ -348,6 +345,12 @@ class KVCacheTransform(ModuleMappingTransform): Gemma2DecoderLayer: QEffGemma2DecoderLayer, Gemma2Model: QEffGemma2Model, Gemma2ForCausalLM: QEffGemma2ForCausalLM, + # Gemma3 + Gemma3Attention: QEffGemma3Attention, + Gemma3DecoderLayer: QEffGemma3DecoderLayer, + Gemma3TextModel: QEffGemma3TextModel, + Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, + Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, From 80ef2cacfafe293287a7c1ae5fa7b81b95626047 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Fri, 9 May 2025 14:07:53 +0530 Subject: [PATCH 4/7] Updating Chunking method (#398) Signed-off-by: Mohit Soni Signed-off-by: Mohit Soni --- .../models/gemma3/modeling_gemma3.py | 2 +- .../transformers/models/modeling_auto.py | 22 +++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 4bd18f311..58b837e9c 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -608,7 +608,7 @@ def get_specializations( vision_seq_len = compiler_options.pop("vision_seq_len", None) if vision_seq_len is None: # TODO: Check properly for Gemma3, Not verified yet. - vision_seq_len = 2560 # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size) + vision_seq_len = 512 # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size) prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6b5deb8db..1a9610187 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -751,8 +751,8 @@ def kv_offload_generate( input_len = inputs["attention_mask"].sum(1, keepdims=True) input_ids_length = inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float - padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - + # padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1] if generation_len is None: generation_len = ctx_len - input_len.max() assert generation_len > 0, "generation length should be greater than zero" @@ -783,18 +783,22 @@ def kv_offload_generate( } vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs["input_ids"] = inputs["input_ids"] + vision_start = perf_counter() vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["input_ids"] = inputs["input_ids"] lang_inputs["position_ids"] = np.where( lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 ) # Need to use -1 as position_ids for invalid tokens vision_session.deactivate() lang_session.activate() - - lang_session.set_buffers(vision_outputs) - + lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"] + # lang_session.set_buffers(vision_outputs) + prefill_start = perf_counter() 
# Run prefill for i in range(num_chunks): chunk_inputs = lang_inputs.copy() @@ -802,9 +806,13 @@ def kv_offload_generate( chunk_inputs["position_ids"] = lang_inputs["position_ids"][ :, i * prefill_seq_len : (i + 1) * prefill_seq_len ] + chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][ + :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ] outputs = lang_session.run(chunk_inputs) - prefill_time = perf_counter() - prefill_start + prefill_time = perf_counter() - prefill_start + vision_end - vision_start + lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len] # Skip inputs/outputs again lang_session.skip_buffers( [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] @@ -838,7 +846,7 @@ def kv_offload_generate( streamer.end() decode_perf = (num_token - 1) / (decode_end - decode_start) - total_time = decode_end - prefill_start + total_time = decode_end - decode_start + prefill_time total_perf = num_token / total_time return CloudAI100ExecInfoNew( From 7c0b51c2698e850d82aacdad2bc0c950adc39061 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 14 May 2025 14:03:24 +0530 Subject: [PATCH 5/7] Gemma3 Adding Merging and Chunking in DecoderWrapper (#402) Signed-off-by: Rishin Raj Signed-off-by: Mohit Soni Signed-off-by: Abukhoyer Shaik Signed-off-by: Asmita Goswami Signed-off-by: vbaddi Signed-off-by: Meet Patel Co-authored-by: Rishin Raj Co-authored-by: Abukhoyer Shaik Co-authored-by: asmigosw Co-authored-by: Vinayak Baddi <68580231+vbaddi@users.noreply.github.com> Co-authored-by: Meet Patel Signed-off-by: Mohit Soni --- QEfficient/cloud/compile.py | 30 ++- QEfficient/cloud/finetune.py | 229 +++++++++++++----- QEfficient/cloud/infer.py | 4 + QEfficient/compile/compile_helper.py | 19 +- QEfficient/finetune/configs/peft_config.py | 21 +- QEfficient/finetune/configs/training.py | 48 +++- QEfficient/finetune/eval.py | 5 +- QEfficient/finetune/utils/config_utils.py | 178 ++++++++++++-- QEfficient/finetune/utils/train_utils.py | 21 +- .../models/gemma3/modeling_gemma3.py | 51 ++-- .../transformers/models/modeling_auto.py | 27 ++- QEfficient/utils/_utils.py | 124 ++++++---- QEfficient/utils/constants.py | 5 +- scripts/Jenkinsfile | 2 +- scripts/finetune/run_ft_model.py | 4 +- tests/finetune/test_finetune.py | 46 +++- tests/transformers/spd/test_pld_inference.py | 4 +- tests/transformers/spd/test_spd_inference.py | 6 +- 18 files changed, 596 insertions(+), 228 deletions(-) diff --git a/QEfficient/cloud/compile.py b/QEfficient/cloud/compile.py index 8b6da5b0b..5f0b9140c 100644 --- a/QEfficient/cloud/compile.py +++ b/QEfficient/cloud/compile.py @@ -85,17 +85,29 @@ parser.add_argument( "--enable_qnn", "--enable-qnn", - action="store_true", + nargs="?", + const=True, + type=str, default=False, help="Enables QNN. 
Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\ If not provided, the default configuration will be used.\ Sample Config: QEfficient/compile/qnn_config.json", ) - parser.add_argument( - "qnn_config", - nargs="?", - type=str, - ) - # FIXME(ochougul): Allow extra compilation arguments - args = parser.parse_args() - QEfficient.compile(**vars(args)) + + args, compiler_options = parser.parse_known_args() + + if isinstance(args.enable_qnn, str): + args.qnn_config = args.enable_qnn + args.enable_qnn = True + + compiler_options_dict = {} + for i in range(0, len(compiler_options)): + if compiler_options[i].startswith("--"): + key = compiler_options[i].lstrip("-").replace("-", "_") + value = ( + compiler_options[i + 1] + if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-") + else True + ) + compiler_options_dict[key] = value + QEfficient.compile(**args.__dict__, **compiler_options_dict) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index f312d00cb..c440e73c0 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -7,6 +7,7 @@ import random import warnings +from typing import Any, Dict, Optional, Union import fire import numpy as np @@ -17,8 +18,9 @@ import torch.utils.data from peft import PeftModel, get_peft_model from torch.optim.lr_scheduler import StepLR +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG +from QEfficient.finetune.configs.training import TrainConfig from QEfficient.finetune.utils.config_utils import ( generate_dataset_config, generate_peft_config, @@ -32,52 +34,81 @@ from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train from QEfficient.utils._utils import login_and_download_hf_lm +# Try importing QAIC-specific module, proceed without it if unavailable try: import torch_qaic # noqa: F401 except ImportError as e: - print(f"Warning: {e}. Moving ahead without these qaic modules.") + print(f"Warning: {e}. Proceeding without QAIC modules.") -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers import AutoModelForSequenceClassification # Suppress all warnings warnings.filterwarnings("ignore") -def main(**kwargs): - """ - Helper function to finetune the model on QAic. +def setup_distributed_training(train_config: TrainConfig) -> None: + """Initialize distributed training environment if enabled. - .. code-block:: bash + Args: + train_config (TrainConfig): Training configuration object. - python -m QEfficient.cloud.finetune OPTIONS + Notes: + - If distributed data parallel (DDP) is disabled, this function does nothing. + - Ensures the device is not CPU and does not specify an index for DDP compatibility. + - Initializes the process group using the specified distributed backend. + Raises: + AssertionError: If device is CPU or includes an index with DDP enabled. 
""" - # update the configuration for the training process - train_config = TRAIN_CONFIG() - update_config(train_config, **kwargs) - dataset_config = generate_dataset_config(train_config, kwargs) - device = train_config.device + if not train_config.enable_ddp: + return - # dist init - if train_config.enable_ddp: - # TODO: may have to init qccl backend, next try run with torchrun command - torch_device = torch.device(device) - assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert torch_device.index is None, ( - f"DDP requires specification of device type only, however provided device index as well: {torch_device}" - ) - dist.init_process_group(backend=train_config.dist_backend) - # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank - getattr(torch, torch_device.type).set_device(dist.get_rank()) + torch_device = torch.device(train_config.device) + assert torch_device.type != "cpu", "Host doesn't support single-node DDP" + assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}" + + dist.init_process_group(backend=train_config.dist_backend) + # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank + getattr(torch, torch_device.type).set_device(dist.get_rank()) - # Set the seeds for reproducibility - torch.manual_seed(train_config.seed) - random.seed(train_config.seed) - np.random.seed(train_config.seed) - # Load the pre-trained model and setup its configuration - # config = AutoConfig.from_pretrained(train_config.model_name) +def setup_seeds(seed: int) -> None: + """Set random seeds across libraries for reproducibility. + + Args: + seed (int): Seed value to set for random number generators. + + Notes: + - Sets seeds for PyTorch, Python's random module, and NumPy. + """ + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + +def load_model_and_tokenizer( + train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs +) -> tuple[AutoModelForCausalLM, AutoTokenizer]: + """Load the pre-trained model and tokenizer from Hugging Face. + + Args: + config (TrainConfig): Training configuration object containing model and tokenizer names. + dataset_config (Any): A dataclass object representing dataset configuration. + peft_config_file (str): Path to PEFT config file used for PEFT finetuning. + kwargs: Additional arguments to override PEFT config. + + Returns: + tuple: A tuple of two values. + - Model with pretrained weights loaded. + - Model's tokenizer (AutoTokenizer). + + Notes: + - Downloads the model if not already cached using login_and_download_hf_lm. + - Configures the model with FP16 precision and disables caching for training. + - Resizes model embeddings if tokenizer vocab size exceeds model embedding size. + - Sets pad_token_id to eos_token_id if not defined in the tokenizer. 
+ """ pretrained_model_path = login_and_download_hf_lm(train_config.model_name) if train_config.task_type == "seq_classification": model = AutoModelForSequenceClassification.from_pretrained( @@ -104,7 +135,6 @@ def main(**kwargs): torch_dtype=torch.float16, ) - # Load the tokenizer and add special tokens tokenizer = AutoTokenizer.from_pretrained( train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name ) @@ -114,14 +144,12 @@ def main(**kwargs): # If there is a mismatch between tokenizer vocab size and embedding matrix, # throw a warning and then expand the embedding matrix if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: - print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") + print("WARNING: Resizing embedding matrix to match tokenizer vocab size.") model.resize_token_embeddings(len(tokenizer)) + # FIXME (Meet): Cover below line inside the logger once it is implemented. print_model_size(model, train_config) - # print the datatype of the model parameters - # print(get_parameter_dtypes(model)) - # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model. # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to # apply gradient checkpointing related hooks to the input embeddings. Without this we will get @@ -134,17 +162,70 @@ def main(**kwargs): else: raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") - if train_config.use_peft: - # Load the pre-trained peft model checkpoint and setup its configuration - if train_config.from_peft_checkpoint: - model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True) - peft_config = model.peft_config - # Generate the peft config and start fine-tuning from original model - else: - peft_config = generate_peft_config(train_config, kwargs) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() + model = apply_peft(model, train_config, peft_config_file, **kwargs) + + return model, tokenizer + + +def apply_peft( + model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs +) -> Union[AutoModel, PeftModel]: + """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled. + + Args: + model (AutoModel): Huggingface model. + train_config (TrainConfig): Training configuration object. + peft_config_file (str, optional): Path to YAML/JSON file containing + PEFT (LoRA) config. Defaults to None. + kwargs: Additional arguments to override PEFT config params. + Returns: + Union[AutoModel, PeftModel]: If the use_peft in train_config is True + then PeftModel object is returned else original model object + (AutoModel) is returned. 
+ """ + if not train_config.use_peft: + return model + + # Load the pre-trained peft model checkpoint and setup its configuration + if train_config.from_peft_checkpoint: + model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True) + peft_config = model.peft_config + # Generate the peft config and start fine-tuning from original model + else: + peft_config = generate_peft_config(train_config, peft_config_file, **kwargs) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + return model + + +def setup_dataloaders( + train_config: TrainConfig, + dataset_config: Any, + tokenizer: AutoTokenizer, +) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]: + """Set up training and validation DataLoaders. + + Args: + train_config (TrainConfig): Training configuration object. + dataset_config (Any): Configuration for the dataset (generated from train_config). + tokenizer (AutoTokenizer): Tokenizer for preprocessing data. + + Returns: + tuple: A tuple of three values. + - First value represents train_dataloader + - Second value represents eval_dataloader. It is None if + validation is disabled. + - Length of longest sequence in the dataset. + + Raises: + ValueError: If validation is enabled but the validation set is too small. + + Notes: + - Applies a custom data collator if provided by get_custom_data_collator. + - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits. + """ # Get the dataset utils dataset_processer = tokenizer @@ -164,6 +245,8 @@ def main(**kwargs): ## train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train") print("length of dataset_train", len(dataset_train)) + + # FIXME (Meet): Add custom data collator registration from the outside by the user. custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config) if custom_data_collator: print("custom_data_collator is used") @@ -208,40 +291,66 @@ def main(**kwargs): else: longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset) + return train_dataloader, eval_dataloader, longest_seq_length + + +def main(peft_config_file: str = None, **kwargs) -> None: + """ + Fine-tune a model on QAIC hardware with configurable training and LoRA parameters. + + Args: + peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None. + kwargs: Additional arguments to override TrainConfig. + + Example: + .. 
code-block:: bash + + # Using a YAML config file for PEFT + python -m QEfficient.cloud.finetune \\ + --model_name "meta-llama/Llama-3.2-1B" \\ + --lr 5e-4 \\ + --peft_config_file "lora_config.yaml" + + # Using default LoRA config + python -m QEfficient.cloud.finetune \\ + --model_name "meta-llama/Llama-3.2-1B" \\ + --lr 5e-4 + """ + train_config = TrainConfig() + update_config(train_config, **kwargs) + dataset_config = generate_dataset_config(train_config.dataset) + update_config(dataset_config, **kwargs) + + setup_distributed_training(train_config) + setup_seeds(train_config.seed) + model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs) + + # Create DataLoaders for the training and validation dataset + train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer) print( f"The longest sequence length in the train data is {longest_seq_length}, " f"passed context length is {train_config.context_length} and overall model's context length is " f"{model.config.max_position_embeddings}" ) + model.to(train_config.device) - optimizer = optim.AdamW( - model.parameters(), - lr=train_config.lr, - weight_decay=train_config.weight_decay, - ) + optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay) scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) - - # wrap model with DDP if train_config.enable_ddp: model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()]) - - _ = train( + results = train( model, + tokenizer, train_dataloader, eval_dataloader, - tokenizer, optimizer, scheduler, - train_config.gradient_accumulation_steps, train_config, - train_config.device, dist.get_rank() if train_config.enable_ddp else None, - None, ) - - # finalize torch distributed if train_config.enable_ddp: dist.destroy_process_group() + return results if __name__ == "__main__": diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 68be72fa8..30e67344a 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -197,6 +197,10 @@ def main( **kwargs, ) + # If the io-encrypt flag is passed we will exit after QPC generation. 
+ if kwargs.get("io_encrypt", None): + exit() + ######### # Execute ######### diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index 5ce22bed9..70a912cd7 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -64,9 +64,6 @@ def compile_kv_model_on_cloud_ai_100( DeprecationWarning, stacklevel=2, ) - if kwargs: - # FIXME - raise NotImplementedError("Can't handle extra compilation args now!") aic_binary_dir = os.path.join(base_path, "qpcs") if os.path.isdir(aic_binary_dir): @@ -111,6 +108,13 @@ def compile_kv_model_on_cloud_ai_100( with open(mdp_ts_config_path, "w") as file: json.dump(mdp_ts_config, file, indent=4) command.append(f"-mdp-load-partition-config={mdp_ts_config_path}") + for key, value in kwargs.items(): + option = "-" + key.replace("_", "-") + if isinstance(value, bool): + if value: + command.append(option) + continue + command.append(f"{option}={value}") print("Running AI 100 compiler:", " ".join(command)) result = subprocess.run(command, capture_output=True, text=True) if result.returncode != 0: @@ -221,6 +225,13 @@ def compile( allow_mxint8_mdp_io=allow_mxint8_mdp_io, mos=mos, device_group=device_group, + **kwargs, ) - logger.info(f"Compiled QPC files can be found here: {qpc_path}") + if kwargs.get("io_encrypt", None): + logger.warning( + f"Compilation for IO-Encrypt has been successfully completed at path: {qpc_path}. However, Efficient-Transformers do not support IO-Encrypt execution. Please run the execution separately" + ) + else: + logger.info(f"Compiled QPC files can be found here: {qpc_path}") + return qpc_path diff --git a/QEfficient/finetune/configs/peft_config.py b/QEfficient/finetune/configs/peft_config.py index e2d018f05..a47774500 100644 --- a/QEfficient/finetune/configs/peft_config.py +++ b/QEfficient/finetune/configs/peft_config.py @@ -9,15 +9,24 @@ from typing import List -# Currently, the support is for Lora Configs only -# In future, we can expand to llama_adapters and prefix tuning -# TODO: vbaddi: Check back once FSDP is enabled @dataclass -class lora_config: +class LoraConfig: + """LoRA-specific configuration for parameter-efficient fine-tuning. + + Attributes: + r (int): LoRA rank (default: 8). + lora_alpha (int): LoRA scaling factor (default: 32). + target_modules (List[str]): Modules to apply LoRA to (default: ["q_proj", "v_proj"]). + bias (str): Bias handling in LoRA (default: "none"). + task_type (str): Task type for LoRA (default: "CAUSAL_LM"). + lora_dropout (float): Dropout rate for LoRA (default: 0.0). + inference_mode (bool): Whether model is in inference mode (default: False). + """ + r: int = 8 lora_alpha: int = 32 target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - bias = "none" + bias: str = "none" task_type: str = "CAUSAL_LM" lora_dropout: float = 0.05 inference_mode: bool = False # should be False for finetuning @@ -25,6 +34,6 @@ class lora_config: # CAUTION prefix tuning is currently not supported @dataclass -class prefix_config: +class PrefixConfig: num_virtual_tokens: int = 30 task_type: str = "CAUSAL_LM" diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py index c50954c4c..69b083b6a 100644 --- a/QEfficient/finetune/configs/training.py +++ b/QEfficient/finetune/configs/training.py @@ -7,8 +7,54 @@ from dataclasses import dataclass +# Configuration Classes @dataclass -class train_config: +class TrainConfig: + """Training configuration for model fine-tuning. 
+ + Attributes: + model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B"). + tokenizer_name (str): Name of the tokenizer (defaults to model_name if None). + run_validation (bool): Whether to run validation during training (default: True). + batch_size_training (int): Batch size for training (default: 1). + context_length (Optional[int]): Maximum sequence length for inputs (default: None). + gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4). + gradient checkpointing (bool): Enable gradient checkpointing to save the memory by compromising the speed. (default: False). + num_epochs (int): Number of training epochs (default: 1). + max_train_step (int): Maximum training steps (default: 0, unlimited if 0). + max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0). + device (str): Device to train on (default: "qaic"). + num_workers_dataloader (int): Number of workers for data loading (default: 1). + lr (float): Learning rate (default: 3e-4). + weight_decay (float): Weight decay for optimizer (default: 0.0). + gamma (float): Learning rate decay factor (default: 0.85). + seed (int): Random seed for reproducibility (default: 42). + use_fp16 (bool): Use mixed precision training (default: True). + use_autocast (bool): Use autocast for mixed precision (default: True). + val_batch_size (int): Batch size for validation (default: 1). + dataset (str): Dataset name for training (default: "samsum_dataset"). + task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation") + peft_method (str): Parameter-efficient fine-tuning method (default: "lora"). + use_peft (bool): Whether to use PEFT (default: True). + from_peft_checkpoint (str): Path to PEFT checkpoint (default: ""). + output_dir (str): Directory to save outputs (default: "meta-llama-samsum"). + num_freeze_layers (int): Number of layers to freeze (default: 1). + one_qaic (bool): Use single QAIC device (default: False). + save_model (bool): Save the trained model (default: True). + save_metrics (bool): Save training metrics (default: True). + intermediate_step_save (int): Steps between intermediate saves (default: 1000). + batching_strategy (str): Batching strategy (default: "packing"). + enable_sorting_for_ddp (bool): Sort data for DDP (default: True). + convergence_counter (int): Steps to check convergence (default: 5). + convergence_loss (float): Loss threshold for convergence (default: 1e-4). + use_profiler (bool): Enable profiling (default: False). + enable_ddp (bool): Enable distributed data parallel (default: False). + dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo"). + grad_scaler (bool): Use gradient scaler (default: True). + dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_"). + opByOpVerifier (bool): Enable operation-by-operation verification (default: False). 
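+    Example (illustrative; ``update_config`` is the override helper from
+    ``QEfficient/finetune/utils/config_utils.py`` and the values mirror the CLI
+    flags shown in the ``main`` docstring):
+
+    .. code-block:: python
+
+        config = TrainConfig()
+        update_config(config, lr=5e-4, num_epochs=2, device="qaic")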
+ """ + model_name: str = "meta-llama/Llama-3.2-1B" tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name run_validation: bool = True diff --git a/QEfficient/finetune/eval.py b/QEfficient/finetune/eval.py index 918230554..3fe6e0d81 100644 --- a/QEfficient/finetune/eval.py +++ b/QEfficient/finetune/eval.py @@ -11,7 +11,6 @@ import fire import numpy as np import torch -from configs.training import train_config as TRAIN_CONFIG from peft import AutoPeftModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer from utils.config_utils import ( @@ -25,6 +24,8 @@ ) from utils.train_utils import evaluation, print_model_size +from QEfficient.finetune.configs.training import TrainConfig + try: import torch_qaic # noqa: F401 @@ -39,7 +40,7 @@ def main(**kwargs): # update the configuration for the training process - train_config = TRAIN_CONFIG() + train_config = TrainConfig() update_config(train_config, **kwargs) # Set the seeds for reproducibility diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index e979961d6..c5c7fe615 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -4,27 +4,39 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- - import inspect +import json +import os from dataclasses import asdict +from typing import Any, Dict import torch.distributed as dist import torch.utils.data as data_utils +import yaml from peft import ( AdaptionPromptConfig, - LoraConfig, PrefixTuningConfig, ) +from peft import LoraConfig as PeftLoraConfig from transformers.data import DataCollatorForSeq2Seq import QEfficient.finetune.configs.dataset_config as datasets -from QEfficient.finetune.configs.peft_config import lora_config, prefix_config -from QEfficient.finetune.configs.training import train_config +from QEfficient.finetune.configs.peft_config import LoraConfig, PrefixConfig +from QEfficient.finetune.configs.training import TrainConfig from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC def update_config(config, **kwargs): + """Update the attributes of a config object based on provided keyword arguments. + + Args: + config: The configuration object (e.g., TrainConfig, LoraConfig) or a list/tuple of such objects. + **kwargs: Keyword arguments representing attributes to update. + + Raises: + ValueError: If an unknown parameter is provided and the config type doesn't support nested updates. + """ if isinstance(config, (tuple, list)): for c in config: update_config(c, **kwargs) @@ -33,40 +45,73 @@ def update_config(config, **kwargs): if hasattr(config, k): setattr(config, k, v) elif "." 
in k: - # allow --some_config.some_param=True - config_name, param_name = k.split(".") - if type(config).__name__ == config_name: + config_name, param_name = k.split(".", 1) + if type(config).__name__.lower() == config_name.lower(): if hasattr(config, param_name): setattr(config, param_name, v) else: - # In case of specialized config we can warn user - assert False, f"Warning: {config_name} does not accept parameter: {k}" - elif isinstance(config, train_config): - assert False, f"Warning: unknown parameter {k}" + raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'") + else: + config_type = type(config).__name__ + # FIXME (Meet): Once logger is available put this in debug level. + print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'") -def generate_peft_config(train_config, kwargs): - configs = (lora_config, prefix_config) - peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) - names = tuple(c.__name__.rstrip("_config") for c in configs) +def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any: + """Generate a PEFT-compatible configuration from a custom config based on peft_method. - if train_config.peft_method not in names: - raise RuntimeError(f"Peft config not found: {train_config.peft_method}") + Args: + train_config (TrainConfig): Training configuration with peft_method. + custom_config: Custom configuration object (e.g., LoraConfig). - config = configs[names.index(train_config.peft_method)]() + Returns: + Any: A PEFT-specific configuration object (e.g., PeftLoraConfig). - update_config(config, **kwargs) - params = asdict(config) - peft_config = peft_configs[names.index(train_config.peft_method)](**params) + Raises: + RuntimeError: If the peft_method is not supported. + """ + if peft_config_file: + peft_config_data = load_config_file(peft_config_file) + validate_config(peft_config_data, config_type="lora") + peft_config = PeftLoraConfig(**peft_config_data) + else: + config_map = { + "lora": (LoraConfig, PeftLoraConfig), + "prefix": (PrefixConfig, PrefixTuningConfig), + "adaption_prompt": (None, AdaptionPromptConfig), + } + + if train_config.peft_method not in config_map: + raise RuntimeError(f"Peft config not found: {train_config.peft_method}") + + config_cls, peft_config_cls = config_map[train_config.peft_method] + if config_cls is None: + params = kwargs + else: + config = config_cls() + update_config(config, **kwargs) + params = asdict(config) + peft_config = peft_config_cls(**params) return peft_config -def generate_dataset_config(train_config, kwargs): - names = tuple(DATASET_PREPROC.keys()) - assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" - dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() - update_config(dataset_config, **kwargs) +def generate_dataset_config(dataset_name: str) -> Any: + """Generate a dataset configuration based on the specified dataset. + + Args: + dataset_name (str): Name of the dataset to be used for finetuning. + + Returns: + Any: A dataset configuration object. + + Raises: + AssertionError: If the dataset name is not recognized. + """ + supported_datasets = DATASET_PREPROC.keys() + assert dataset_name in supported_datasets, f"Given dataset '{dataset_name}' is not supported." + # FIXME (Meet): Replace below logic by creating using auto registry of datasets. 
+ dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[dataset_name]() return dataset_config @@ -98,3 +143,84 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): kwargs["drop_last"] = True kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) return kwargs + + +def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None: + """Validate the provided YAML/JSON configuration for required fields and types. + + Args: + config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON. + config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora"). + + Raises: + ValueError: If required fields are missing or have incorrect types. + FileNotFoundError: If the config file path is invalid (handled upstream). + + Notes: + - Validates required fields for LoraConfig: r, lora_alpha, target_modules. + - Ensures types match expected values (int, float, list, etc.). + """ + if config_type.lower() != "lora": + raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.") + + required_fields = { + "r": int, + "lora_alpha": int, + "target_modules": list, + } + optional_fields = { + "bias": str, + "task_type": str, + "lora_dropout": float, + "inference_mode": bool, + } + + # Check for missing required fields + missing_fields = [field for field in required_fields if field not in config_data] + if missing_fields: + raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}") + + # Validate types of required fields + for field, expected_type in required_fields.items(): + if not isinstance(config_data[field], expected_type): + raise ValueError( + f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, " + f"got {type(config_data[field]).__name__}" + ) + + # Validate target_modules contains strings + if not all(isinstance(mod, str) for mod in config_data["target_modules"]): + raise ValueError("All elements in 'target_modules' must be strings") + + # Validate types of optional fields if present + for field, expected_type in optional_fields.items(): + if field in config_data and not isinstance(config_data[field], expected_type): + raise ValueError( + f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, " + f"got {type(config_data[field]).__name__}" + ) + + +def load_config_file(config_path: str) -> Dict[str, Any]: + """Load a configuration from a YAML or JSON file. + + Args: + config_path (str): Path to the YAML or JSON file. + + Returns: + Dict[str, Any]: The loaded configuration as a dictionary. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the file format is unsupported. + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + + with open(config_path, "r") as f: + if config_path.endswith(".yaml") or config_path.endswith(".yml"): + return yaml.safe_load(f) + elif config_path.endswith(".json"): + return json.load(f) + else: + raise ValueError("Unsupported config file format. 
Use .yaml, .yml, or .json") diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 2bc701008..8693ae32d 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -18,7 +18,7 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG +from QEfficient.finetune.configs.training import TrainConfig try: import torch_qaic # noqa: F401 @@ -34,34 +34,31 @@ def train( model, + tokenizer, train_dataloader, eval_dataloader, - tokenizer, optimizer, lr_scheduler, - gradient_accumulation_steps, - train_config: TRAIN_CONFIG, - device, + train_config: TrainConfig, local_rank=None, - rank=None, ): """ Trains the model on the given dataloader Args: model: The model to be trained + tokenizer: tokenizer used in the eval for decoding the predicitons train_dataloader: The dataloader containing the training data + eval_dataloader: The dataloader containing the eval data optimizer: The optimizer used for training lr_scheduler: The learning rate scheduler - gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation - num_epochs: The number of epochs to train for - local_rank: The rank of the current node in a distributed setting train_config: The training configuration - eval_dataloader: The dataloader containing the eval data - tokenizer: tokenizer used in the eval for decoding the predicitons + local_rank: The rank of the current node in a distributed setting Returns: results dictionary containing average training and validation perplexity and loss """ + device = train_config.device + train_metric = [] train_loss = [] val_metric = [] @@ -461,7 +458,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device): # Print evaluation metrics print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") - return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric + return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 58b837e9c..70601489d 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -560,16 +560,9 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.vision_tower - def forward(self, input_ids, pixel_values): - inputs_embeds = self.model.get_input_embeddings()(input_ids) - B, N, C = inputs_embeds.shape + def forward(self, pixel_values): image_features = self.model.get_image_features(pixel_values=pixel_values) - selected = input_ids == self.model.config.image_token_index - indices1 = selected.to(torch.int64).cumsum(1) - 1 - indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1] - image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) - return image_input_embeds + return image_features class QEffGemma3DecoderWrapper(nn.Module): @@ -579,14 +572,21 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, past_key_values): - image_embeds = vision_embeds[:, : 
input_ids.shape[1], :] - inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) - inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + def forward(self, input_ids, vision_embeds, position_ids, index, past_key_values): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_index + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + index, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True ) - return outputs.logits, vision_embeds, outputs.past_key_values + index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return outputs.logits, vision_embeds, index, outputs.past_key_values class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration): @@ -605,11 +605,6 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): - vision_seq_len = compiler_options.pop("vision_seq_len", None) - if vision_seq_len is None: - # TODO: Check properly for Gemma3, Not verified yet. - vision_seq_len = 512 # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size) - prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN if img_size is None and hasattr(self.config.vision_config, "image_size"): @@ -617,12 +612,13 @@ def get_specializations( elif img_size is None: img_size = 896 # FIXME based on gemma3 Image size logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") + mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) vision = [ { "batch_size": batch_size, "img_size": img_size, - "seq_len": vision_seq_len, + "seq_len": prefill_seq_len, "ctx_len": ctx_len, } ] @@ -632,14 +628,14 @@ def get_specializations( "seq_len": prefill_seq_len, "ctx_len": ctx_len, "img_size": img_size, - "chunk_length": prefill_seq_len, + "mm_tokens_per_image": mm_tokens_per_image, }, { "batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "img_size": img_size, - "chunk_length": prefill_seq_len, + "mm_tokens_per_image": mm_tokens_per_image, }, ] @@ -658,9 +654,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"} + lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} - vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} for i in range(self.language_model.config.num_hidden_layers): @@ -685,6 +680,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, 
"index_output") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: @@ -698,12 +694,13 @@ def get_dummy_inputs(self, kv_offload: bool = False): else: img_size = 896 + mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, # constants.INTERN_NUM_PATCHES, - constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, # constants.INTERN_FEATURE_SIZE, + mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, self.language_model.config.hidden_size, # 5120 ) inputs_shapes["position_ids"] = ( @@ -716,12 +713,12 @@ def get_dummy_inputs(self, kv_offload: bool = False): img_size, img_size, ) + inputs_shapes["index"] = (1, 1) # Define inputs vision_inputs = {} lang_inputs = {} vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) - vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) lang_inputs["position_ids"] = ( @@ -729,7 +726,7 @@ def get_dummy_inputs(self, kv_offload: bool = False): .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) - + lang_inputs["index"] = torch.zeros((inputs_shapes["index"]), dtype=torch.int64) # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 1a9610187..ebfd529cc 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -751,8 +751,8 @@ def kv_offload_generate( input_len = inputs["attention_mask"].sum(1, keepdims=True) input_ids_length = inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float - # padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1] + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + if generation_len is None: generation_len = ctx_len - input_len.max() assert generation_len > 0, "generation length should be greater than zero" @@ -783,13 +783,11 @@ def kv_offload_generate( } vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") - vision_inputs["input_ids"] = inputs["input_ids"] vision_start = perf_counter() vision_outputs = vision_session.run(vision_inputs) vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - lang_inputs["input_ids"] = inputs["input_ids"] lang_inputs["position_ids"] = np.where( lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 ) # Need to use -1 as position_ids for invalid tokens @@ -797,25 +795,27 @@ def kv_offload_generate( vision_session.deactivate() lang_session.activate() lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"] - # lang_session.set_buffers(vision_outputs) + lang_session.set_buffers(vision_outputs) prefill_start = perf_counter() # Run prefill + chunk_inputs = lang_inputs.copy() + chunk_inputs["index"] = np.array([[0]]) for i in range(num_chunks): - chunk_inputs = 
lang_inputs.copy() chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ :, i * prefill_seq_len : (i + 1) * prefill_seq_len ] - chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][ - :, i * prefill_seq_len : (i + 1) * prefill_seq_len - ] outputs = lang_session.run(chunk_inputs) + chunk_inputs["index"] = outputs["index_output"] prefill_time = perf_counter() - prefill_start + vision_end - vision_start - lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len] # Skip inputs/outputs again lang_session.skip_buffers( - [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) # Get first token @@ -1643,6 +1643,11 @@ def compile( **compiler_options, ) + if compiler_options.get("io_encrypt", None): + logger.warning( + "Compilation for IO-Encrypt has been successfully completed. However, Efficient-Transformers do not support IO-Encrypt execution. Please run the execution separately with QPC compiled without io-encrypt." + ) + return qpc_path # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index b6af66be5..564bdd94d 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -521,27 +521,57 @@ def __repr__(self): def dump_qconfig(func): def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) - create_and_dump_qconfigs( - self.qpc_path, - self.onnx_path, - self.get_model_config, - [cls.__name__ for cls in self._pytorch_transforms], - [cls.__name__ for cls in self._onnx_transforms], - kwargs.get("specializations"), - kwargs.get("mdp_ts_num_devices", 1), - kwargs.get("num_speculative_tokens"), - **{ - k: v - for k, v in kwargs.items() - if k - not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"] - }, - ) + try: + create_and_dump_qconfigs( + self.qpc_path, + self.onnx_path, + self.get_model_config, + [cls.__name__ for cls in self._pytorch_transforms], + [cls.__name__ for cls in self._onnx_transforms], + kwargs.get("specializations"), + kwargs.get("mdp_ts_num_devices", 1), + kwargs.get("num_speculative_tokens"), + **{ + k: v + for k, v in kwargs.items() + if k + not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"] + }, + ) + except Exception as e: + print(f"An unexpected error occurred while dumping the qconfig: {e}") return result return wrapper +def get_qaic_sdk_version(qaic_sdk_xml_path: str) -> Optional[str]: + """ + Extracts the QAIC SDK version from the given SDK XML file. + + Args: + qaic_sdk_xml_path (str): Path to the SDK XML file. + Returns: + The SDK version as a string if found, otherwise None. 
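+    Example (illustrative; the path matches ``Constants.SDK_APPS_XML`` defined
+    in ``QEfficient/utils/constants.py``):
+
+    .. code-block:: python
+
+        apps_version = get_qaic_sdk_version("/opt/qti-aic/versions/apps.xml")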
+ """ + qaic_sdk_version = None + + # Check and extract version from the given SDK XML file + if os.path.exists(qaic_sdk_xml_path): + try: + tree = ET.parse(qaic_sdk_xml_path) + root = tree.getroot() + base_version_element = root.find(".//base_version") + if base_version_element is not None: + qaic_sdk_version = base_version_element.text + except ET.ParseError as e: + print(f"Error parsing XML file {qaic_sdk_xml_path}: {e}") + except Exception as e: + print(f"An unexpected error occurred while processing {qaic_sdk_xml_path}: {e}") + + return qaic_sdk_version + + def create_and_dump_qconfigs( qpc_path, onnx_path, @@ -558,29 +588,12 @@ def create_and_dump_qconfigs( Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and many other compilation options. """ - qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None - enable_qnn = True if "qnn_config" in compiler_options else None - + enable_qnn = compiler_options.get("enable_qnn", False) + qnn_config_path = compiler_options.get("qnn_config", None) qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json") onnx_path = str(onnx_path) specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json")) compile_dir = str(os.path.dirname(qpc_path)) - qnn_config_path = ( - (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None - ) - - # Extract QAIC SDK Apps Version from SDK XML file - tree = ET.parse(Constants.SDK_APPS_XML) - root = tree.getroot() - qaic_version = root.find(".//base_version").text - - # Extract QNN SDK details from YAML file if the environment variable is set - qnn_sdk_details = None - qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) - if enable_qnn and qnn_sdk_path: - qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML) - with open(qnn_sdk_yaml_path, "r") as file: - qnn_sdk_details = yaml.safe_load(file) # Ensure all objects in the configs dictionary are JSON serializable def make_serializable(obj): @@ -602,29 +615,38 @@ def make_serializable(obj): "onnx_transforms": make_serializable(onnx_transforms), "onnx_path": onnx_path, }, + "compiler_config": { + "enable_qnn": enable_qnn, + "compile_dir": compile_dir, + "specializations_file_path": specializations_file_path, + "specializations": make_serializable(specializations), + "mdp_ts_num_devices": mdp_ts_num_devices, + "num_speculative_tokens": num_speculative_tokens, + **compiler_options, + }, + "aic_sdk_config": { + "qaic_apps_version": get_qaic_sdk_version(Constants.SDK_APPS_XML), + "qaic_platform_version": get_qaic_sdk_version(Constants.SDK_PLATFORM_XML), + }, }, } - aic_compiler_config = { - "apps_sdk_version": qaic_version, - "compile_dir": compile_dir, - "specializations_file_path": specializations_file_path, - "specializations": make_serializable(specializations), - "mdp_ts_num_devices": mdp_ts_num_devices, - "num_speculative_tokens": num_speculative_tokens, - **compiler_options, - } - qnn_config = { - "enable_qnn": enable_qnn, - "qnn_config_path": qnn_config_path, - } - # Put AIC or qnn details. if enable_qnn: + qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) + if not qnn_sdk_path: + raise EnvironmentError( + f"QNN_SDK_PATH {qnn_sdk_path} is not set. 
Please set {QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME}" + ) + qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML) + qnn_sdk_details = load_yaml( + qnn_sdk_yaml_path + ) # Extract QNN SDK details from YAML file if the environment variable is set + qnn_config = { + "qnn_config_path": qnn_config_path, + } qconfigs["qpc_config"]["qnn_config"] = qnn_config if qnn_sdk_details: qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details) - else: - qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config create_json(qconfig_file_path, qconfigs) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index b1ff9701e..c8f74907a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -97,7 +97,10 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 - SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK version. + SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. + SDK_PLATFORM_XML = ( + "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. + ) @dataclass diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index fcd2fece5..7036d6f6d 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -171,4 +171,4 @@ pipeline { deleteDir() } } -} +} \ No newline at end of file diff --git a/scripts/finetune/run_ft_model.py b/scripts/finetune/run_ft_model.py index 5e88db641..ef014923b 100644 --- a/scripts/finetune/run_ft_model.py +++ b/scripts/finetune/run_ft_model.py @@ -12,7 +12,7 @@ from peft import AutoPeftModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer -from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG +from QEfficient.finetune.configs.training import TrainConfig # Suppress all warnings warnings.filterwarnings("ignore") @@ -25,7 +25,7 @@ print(f"Warning: {e}. Moving ahead without these qaic modules.") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -train_config = TRAIN_CONFIG() +train_config = TrainConfig() model = AutoModelForCausalLM.from_pretrained( train_config.model_name, use_cache=False, diff --git a/tests/finetune/test_finetune.py b/tests/finetune/test_finetune.py index 45330cad6..fb4a84dc0 100644 --- a/tests/finetune/test_finetune.py +++ b/tests/finetune/test_finetune.py @@ -8,6 +8,7 @@ import os import shutil +import numpy as np import pytest import torch.optim as optim from torch.utils.data import DataLoader @@ -22,12 +23,25 @@ def clean_up(path): shutil.rmtree(path) -configs = [pytest.param("meta-llama/Llama-3.2-1B", 1, 1, 1, None, True, True, "cpu", id="llama_config")] +configs = [ + pytest.param( + "meta-llama/Llama-3.2-1B", # model_name + 10, # max_eval_step + 20, # max_train_step + 1, # intermediate_step_save + None, # context_length + True, # run_validation + True, # use_peft + "qaic", # device + id="llama_config", # config name + ) +] -# TODO:enable this once docker is available +@pytest.mark.skip(reason="Currently CI is broken. 
Once it is fixed we will enable this test.") +@pytest.mark.cli @pytest.mark.on_qaic -@pytest.mark.skip(reason="eager docker not available in sdk") +@pytest.mark.finetune @pytest.mark.parametrize( "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device", configs, @@ -43,7 +57,7 @@ def test_finetune( device, mocker, ): - train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TRAIN_CONFIG") + train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig") generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config") generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config") get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs") @@ -65,23 +79,28 @@ def test_finetune( "device": device, } - finetune(**kwargs) + results = finetune(**kwargs) + assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching." + assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching." + assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching." + assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching." + assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds." train_config_spy.assert_called_once() generate_dataset_config_spy.assert_called_once() generate_peft_config_spy.assert_called_once() - update_config_spy.assert_called_once() get_custom_data_collator_spy.assert_called_once() get_longest_seq_length_spy.assert_called_once() print_model_size_spy.assert_called_once() train_spy.assert_called_once() + assert update_config_spy.call_count == 2 assert get_dataloader_kwargs_spy.call_count == 2 assert get_preprocessed_dataset_spy.call_count == 2 args, kwargs = train_spy.call_args - train_dataloader = args[1] - eval_dataloader = args[2] + train_dataloader = args[2] + eval_dataloader = args[3] optimizer = args[4] batch = next(iter(train_dataloader)) @@ -97,12 +116,19 @@ def test_finetune( else: assert eval_dataloader is None - args, kwargs = update_config_spy.call_args + args, kwargs = update_config_spy.call_args_list[0] train_config = args[0] + assert max_train_step >= train_config.gradient_accumulation_steps, ( + "Total training step should be more than " + f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps." 
+ ) - saved_file = os.path.join(train_config.output_dir, "adapter_model.safetensors") + saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors") assert os.path.isfile(saved_file) clean_up(train_config.output_dir) clean_up("runs") clean_up(train_config.dump_root_dir) + + +# TODO (Meet): Add seperate tests for BERT FT and LLama FT diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index c80fe5969..71b4e01cd 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -262,7 +262,7 @@ def test_pld_spec_decode_inference( num_speculative_tokens=num_speculative_tokens, ) # init qaic session - target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) + target_model_session = QAICInferenceSession(target_model_qpc_path) draft_model_session = None # skip inputs/outputs buffers @@ -453,7 +453,7 @@ def test_pld_spec_decode_inference( del draft_model_session generated_ids = np.asarray(generated_ids[0]).flatten() gen_len = generated_ids.shape[0] - exec_info = target_model.generate(tokenizer, Constants.INPUT_STR, device_group) + exec_info = target_model.generate(tokenizer, Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0][ :gen_len ] # Because we always run for single input and single batch size diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 6f6bdb268..e87c51d5f 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -157,8 +157,8 @@ def test_spec_decode_inference( full_batch_size=full_batch_size, ) # init qaic session - target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) - draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group) + target_model_session = QAICInferenceSession(target_model_qpc_path) + draft_model_session = QAICInferenceSession(draft_model_qpc_path) # skip inputs/outputs buffers target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")])) @@ -341,7 +341,7 @@ def test_spec_decode_inference( del draft_model_session generated_ids = np.asarray(generated_ids[0]).flatten() gen_len = generated_ids.shape[0] - exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group) + exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0][ :gen_len ] # Because we always run for single input and single batch size From e532d74a185b13dc5df783e9db1465a1563f016a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 14 May 2025 09:29:59 +0000 Subject: [PATCH 6/7] Revert "Gemma3 Adding Merging and Chunking in DecoderWrapper (#402)" This reverts commit 70ae12f06782c868438024cd22f900c871d64dca. 
Signed-off-by: Mohit Soni --- QEfficient/cloud/compile.py | 30 +-- QEfficient/cloud/finetune.py | 229 +++++------------- QEfficient/cloud/infer.py | 4 - QEfficient/compile/compile_helper.py | 19 +- QEfficient/finetune/configs/peft_config.py | 21 +- QEfficient/finetune/configs/training.py | 48 +--- QEfficient/finetune/eval.py | 5 +- QEfficient/finetune/utils/config_utils.py | 178 ++------------ QEfficient/finetune/utils/train_utils.py | 21 +- .../models/gemma3/modeling_gemma3.py | 51 ++-- .../transformers/models/modeling_auto.py | 27 +-- QEfficient/utils/_utils.py | 124 ++++------ QEfficient/utils/constants.py | 5 +- scripts/Jenkinsfile | 2 +- scripts/finetune/run_ft_model.py | 4 +- tests/finetune/test_finetune.py | 46 +--- tests/transformers/spd/test_pld_inference.py | 4 +- tests/transformers/spd/test_spd_inference.py | 6 +- 18 files changed, 228 insertions(+), 596 deletions(-) diff --git a/QEfficient/cloud/compile.py b/QEfficient/cloud/compile.py index 5f0b9140c..8b6da5b0b 100644 --- a/QEfficient/cloud/compile.py +++ b/QEfficient/cloud/compile.py @@ -85,29 +85,17 @@ parser.add_argument( "--enable_qnn", "--enable-qnn", - nargs="?", - const=True, - type=str, + action="store_true", default=False, help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\ If not provided, the default configuration will be used.\ Sample Config: QEfficient/compile/qnn_config.json", ) - - args, compiler_options = parser.parse_known_args() - - if isinstance(args.enable_qnn, str): - args.qnn_config = args.enable_qnn - args.enable_qnn = True - - compiler_options_dict = {} - for i in range(0, len(compiler_options)): - if compiler_options[i].startswith("--"): - key = compiler_options[i].lstrip("-").replace("-", "_") - value = ( - compiler_options[i + 1] - if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-") - else True - ) - compiler_options_dict[key] = value - QEfficient.compile(**args.__dict__, **compiler_options_dict) + parser.add_argument( + "qnn_config", + nargs="?", + type=str, + ) + # FIXME(ochougul): Allow extra compilation arguments + args = parser.parse_args() + QEfficient.compile(**vars(args)) diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py index c440e73c0..f312d00cb 100644 --- a/QEfficient/cloud/finetune.py +++ b/QEfficient/cloud/finetune.py @@ -7,7 +7,6 @@ import random import warnings -from typing import Any, Dict, Optional, Union import fire import numpy as np @@ -18,9 +17,8 @@ import torch.utils.data from peft import PeftModel, get_peft_model from torch.optim.lr_scheduler import StepLR -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from QEfficient.finetune.configs.training import TrainConfig +from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG from QEfficient.finetune.utils.config_utils import ( generate_dataset_config, generate_peft_config, @@ -34,81 +32,52 @@ from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train from QEfficient.utils._utils import login_and_download_hf_lm -# Try importing QAIC-specific module, proceed without it if unavailable try: import torch_qaic # noqa: F401 except ImportError as e: - print(f"Warning: {e}. Proceeding without QAIC modules.") + print(f"Warning: {e}. 
Moving ahead without these qaic modules.") -from transformers import AutoModelForSequenceClassification +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer # Suppress all warnings warnings.filterwarnings("ignore") -def setup_distributed_training(train_config: TrainConfig) -> None: - """Initialize distributed training environment if enabled. - - Args: - train_config (TrainConfig): Training configuration object. - - Notes: - - If distributed data parallel (DDP) is disabled, this function does nothing. - - Ensures the device is not CPU and does not specify an index for DDP compatibility. - - Initializes the process group using the specified distributed backend. - - Raises: - AssertionError: If device is CPU or includes an index with DDP enabled. +def main(**kwargs): """ - if not train_config.enable_ddp: - return + Helper function to finetune the model on QAic. - torch_device = torch.device(train_config.device) - assert torch_device.type != "cpu", "Host doesn't support single-node DDP" - assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}" + .. code-block:: bash - dist.init_process_group(backend=train_config.dist_backend) - # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank - getattr(torch, torch_device.type).set_device(dist.get_rank()) + python -m QEfficient.cloud.finetune OPTIONS + """ + # update the configuration for the training process + train_config = TRAIN_CONFIG() + update_config(train_config, **kwargs) + dataset_config = generate_dataset_config(train_config, kwargs) + device = train_config.device -def setup_seeds(seed: int) -> None: - """Set random seeds across libraries for reproducibility. + # dist init + if train_config.enable_ddp: + # TODO: may have to init qccl backend, next try run with torchrun command + torch_device = torch.device(device) + assert torch_device.type != "cpu", "Host doesn't support single-node DDP" + assert torch_device.index is None, ( + f"DDP requires specification of device type only, however provided device index as well: {torch_device}" + ) + dist.init_process_group(backend=train_config.dist_backend) + # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank + getattr(torch, torch_device.type).set_device(dist.get_rank()) - Args: - seed (int): Seed value to set for random number generators. + # Set the seeds for reproducibility + torch.manual_seed(train_config.seed) + random.seed(train_config.seed) + np.random.seed(train_config.seed) - Notes: - - Sets seeds for PyTorch, Python's random module, and NumPy. - """ - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -def load_model_and_tokenizer( - train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs -) -> tuple[AutoModelForCausalLM, AutoTokenizer]: - """Load the pre-trained model and tokenizer from Hugging Face. - - Args: - config (TrainConfig): Training configuration object containing model and tokenizer names. - dataset_config (Any): A dataclass object representing dataset configuration. - peft_config_file (str): Path to PEFT config file used for PEFT finetuning. - kwargs: Additional arguments to override PEFT config. - - Returns: - tuple: A tuple of two values. - - Model with pretrained weights loaded. - - Model's tokenizer (AutoTokenizer). - - Notes: - - Downloads the model if not already cached using login_and_download_hf_lm. 
- - Configures the model with FP16 precision and disables caching for training. - - Resizes model embeddings if tokenizer vocab size exceeds model embedding size. - - Sets pad_token_id to eos_token_id if not defined in the tokenizer. - """ + # Load the pre-trained model and setup its configuration + # config = AutoConfig.from_pretrained(train_config.model_name) pretrained_model_path = login_and_download_hf_lm(train_config.model_name) if train_config.task_type == "seq_classification": model = AutoModelForSequenceClassification.from_pretrained( @@ -135,6 +104,7 @@ def load_model_and_tokenizer( torch_dtype=torch.float16, ) + # Load the tokenizer and add special tokens tokenizer = AutoTokenizer.from_pretrained( train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name ) @@ -144,12 +114,14 @@ def load_model_and_tokenizer( # If there is a mismatch between tokenizer vocab size and embedding matrix, # throw a warning and then expand the embedding matrix if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: - print("WARNING: Resizing embedding matrix to match tokenizer vocab size.") + print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") model.resize_token_embeddings(len(tokenizer)) - # FIXME (Meet): Cover below line inside the logger once it is implemented. print_model_size(model, train_config) + # print the datatype of the model parameters + # print(get_parameter_dtypes(model)) + # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model. # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to # apply gradient checkpointing related hooks to the input embeddings. Without this we will get @@ -162,70 +134,17 @@ def load_model_and_tokenizer( else: raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.") - model = apply_peft(model, train_config, peft_config_file, **kwargs) - - return model, tokenizer - - -def apply_peft( - model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs -) -> Union[AutoModel, PeftModel]: - """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled. - - Args: - model (AutoModel): Huggingface model. - train_config (TrainConfig): Training configuration object. - peft_config_file (str, optional): Path to YAML/JSON file containing - PEFT (LoRA) config. Defaults to None. - kwargs: Additional arguments to override PEFT config params. + if train_config.use_peft: + # Load the pre-trained peft model checkpoint and setup its configuration + if train_config.from_peft_checkpoint: + model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True) + peft_config = model.peft_config + # Generate the peft config and start fine-tuning from original model + else: + peft_config = generate_peft_config(train_config, kwargs) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() - Returns: - Union[AutoModel, PeftModel]: If the use_peft in train_config is True - then PeftModel object is returned else original model object - (AutoModel) is returned. 
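A compact, runnable sketch of the tokenizer and embedding handling those notes describe, using the stock Hugging Face API and the default TrainConfig model name (illustrative only, not from this patch):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"   # default TrainConfig.model_name
model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=False, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fall back to EOS as the padding token when the tokenizer does not define one.
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# If added special tokens made the vocab larger than the embedding matrix, grow the matrix.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))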
- """ - if not train_config.use_peft: - return model - - # Load the pre-trained peft model checkpoint and setup its configuration - if train_config.from_peft_checkpoint: - model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True) - peft_config = model.peft_config - # Generate the peft config and start fine-tuning from original model - else: - peft_config = generate_peft_config(train_config, peft_config_file, **kwargs) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() - - return model - - -def setup_dataloaders( - train_config: TrainConfig, - dataset_config: Any, - tokenizer: AutoTokenizer, -) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]: - """Set up training and validation DataLoaders. - - Args: - train_config (TrainConfig): Training configuration object. - dataset_config (Any): Configuration for the dataset (generated from train_config). - tokenizer (AutoTokenizer): Tokenizer for preprocessing data. - - Returns: - tuple: A tuple of three values. - - First value represents train_dataloader - - Second value represents eval_dataloader. It is None if - validation is disabled. - - Length of longest sequence in the dataset. - - Raises: - ValueError: If validation is enabled but the validation set is too small. - - Notes: - - Applies a custom data collator if provided by get_custom_data_collator. - - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits. - """ # Get the dataset utils dataset_processer = tokenizer @@ -245,8 +164,6 @@ def setup_dataloaders( ## train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train") print("length of dataset_train", len(dataset_train)) - - # FIXME (Meet): Add custom data collator registration from the outside by the user. custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config) if custom_data_collator: print("custom_data_collator is used") @@ -291,66 +208,40 @@ def setup_dataloaders( else: longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset) - return train_dataloader, eval_dataloader, longest_seq_length - - -def main(peft_config_file: str = None, **kwargs) -> None: - """ - Fine-tune a model on QAIC hardware with configurable training and LoRA parameters. - - Args: - peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None. - kwargs: Additional arguments to override TrainConfig. - - Example: - .. 
code-block:: bash - - # Using a YAML config file for PEFT - python -m QEfficient.cloud.finetune \\ - --model_name "meta-llama/Llama-3.2-1B" \\ - --lr 5e-4 \\ - --peft_config_file "lora_config.yaml" - - # Using default LoRA config - python -m QEfficient.cloud.finetune \\ - --model_name "meta-llama/Llama-3.2-1B" \\ - --lr 5e-4 - """ - train_config = TrainConfig() - update_config(train_config, **kwargs) - dataset_config = generate_dataset_config(train_config.dataset) - update_config(dataset_config, **kwargs) - - setup_distributed_training(train_config) - setup_seeds(train_config.seed) - model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs) - - # Create DataLoaders for the training and validation dataset - train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer) print( f"The longest sequence length in the train data is {longest_seq_length}, " f"passed context length is {train_config.context_length} and overall model's context length is " f"{model.config.max_position_embeddings}" ) - model.to(train_config.device) - optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay) + optimizer = optim.AdamW( + model.parameters(), + lr=train_config.lr, + weight_decay=train_config.weight_decay, + ) scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) + + # wrap model with DDP if train_config.enable_ddp: model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()]) - results = train( + + _ = train( model, - tokenizer, train_dataloader, eval_dataloader, + tokenizer, optimizer, scheduler, + train_config.gradient_accumulation_steps, train_config, + train_config.device, dist.get_rank() if train_config.enable_ddp else None, + None, ) + + # finalize torch distributed if train_config.enable_ddp: dist.destroy_process_group() - return results if __name__ == "__main__": diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 30e67344a..68be72fa8 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -197,10 +197,6 @@ def main( **kwargs, ) - # If the io-encrypt flag is passed we will exit after QPC generation. 
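The distributed pieces above (init_process_group, per-rank device selection, the DistributedDataParallel wrap and destroy_process_group) follow the standard torch.distributed recipe. A minimal CPU/gloo sketch, assuming the script is launched with torchrun so RANK and WORLD_SIZE are populated (not from this patch):

import torch
import torch.distributed as dist
from torch import nn

# e.g. torchrun --nproc-per-node 2 ddp_sketch.py
dist.init_process_group(backend="gloo")        # the patch picks the backend per device type via dist_backend
rank = dist.get_rank()

model = nn.Linear(8, 2)                        # stand-in for the fine-tuned model
ddp_model = nn.parallel.DistributedDataParallel(model)   # no device_ids needed for CPU/gloo

out = ddp_model(torch.randn(4, 8))
print(f"rank {rank} produced output of shape {tuple(out.shape)}")

dist.destroy_process_group()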
- if kwargs.get("io_encrypt", None): - exit() - ######### # Execute ######### diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index 70a912cd7..5ce22bed9 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -64,6 +64,9 @@ def compile_kv_model_on_cloud_ai_100( DeprecationWarning, stacklevel=2, ) + if kwargs: + # FIXME + raise NotImplementedError("Can't handle extra compilation args now!") aic_binary_dir = os.path.join(base_path, "qpcs") if os.path.isdir(aic_binary_dir): @@ -108,13 +111,6 @@ def compile_kv_model_on_cloud_ai_100( with open(mdp_ts_config_path, "w") as file: json.dump(mdp_ts_config, file, indent=4) command.append(f"-mdp-load-partition-config={mdp_ts_config_path}") - for key, value in kwargs.items(): - option = "-" + key.replace("_", "-") - if isinstance(value, bool): - if value: - command.append(option) - continue - command.append(f"{option}={value}") print("Running AI 100 compiler:", " ".join(command)) result = subprocess.run(command, capture_output=True, text=True) if result.returncode != 0: @@ -225,13 +221,6 @@ def compile( allow_mxint8_mdp_io=allow_mxint8_mdp_io, mos=mos, device_group=device_group, - **kwargs, ) - if kwargs.get("io_encrypt", None): - logger.warning( - f"Compilation for IO-Encrypt has been successfully completed at path: {qpc_path}. However, Efficient-Transformers do not support IO-Encrypt execution. Please run the execution separately" - ) - else: - logger.info(f"Compiled QPC files can be found here: {qpc_path}") - + logger.info(f"Compiled QPC files can be found here: {qpc_path}") return qpc_path diff --git a/QEfficient/finetune/configs/peft_config.py b/QEfficient/finetune/configs/peft_config.py index a47774500..e2d018f05 100644 --- a/QEfficient/finetune/configs/peft_config.py +++ b/QEfficient/finetune/configs/peft_config.py @@ -9,24 +9,15 @@ from typing import List +# Currently, the support is for Lora Configs only +# In future, we can expand to llama_adapters and prefix tuning +# TODO: vbaddi: Check back once FSDP is enabled @dataclass -class LoraConfig: - """LoRA-specific configuration for parameter-efficient fine-tuning. - - Attributes: - r (int): LoRA rank (default: 8). - lora_alpha (int): LoRA scaling factor (default: 32). - target_modules (List[str]): Modules to apply LoRA to (default: ["q_proj", "v_proj"]). - bias (str): Bias handling in LoRA (default: "none"). - task_type (str): Task type for LoRA (default: "CAUSAL_LM"). - lora_dropout (float): Dropout rate for LoRA (default: 0.0). - inference_mode (bool): Whether model is in inference mode (default: False). - """ - +class lora_config: r: int = 8 lora_alpha: int = 32 target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - bias: str = "none" + bias = "none" task_type: str = "CAUSAL_LM" lora_dropout: float = 0.05 inference_mode: bool = False # should be False for finetuning @@ -34,6 +25,6 @@ class LoraConfig: # CAUTION prefix tuning is currently not supported @dataclass -class PrefixConfig: +class prefix_config: num_virtual_tokens: int = 30 task_type: str = "CAUSAL_LM" diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py index 69b083b6a..c50954c4c 100644 --- a/QEfficient/finetune/configs/training.py +++ b/QEfficient/finetune/configs/training.py @@ -7,54 +7,8 @@ from dataclasses import dataclass -# Configuration Classes @dataclass -class TrainConfig: - """Training configuration for model fine-tuning. 
- - Attributes: - model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B"). - tokenizer_name (str): Name of the tokenizer (defaults to model_name if None). - run_validation (bool): Whether to run validation during training (default: True). - batch_size_training (int): Batch size for training (default: 1). - context_length (Optional[int]): Maximum sequence length for inputs (default: None). - gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4). - gradient checkpointing (bool): Enable gradient checkpointing to save the memory by compromising the speed. (default: False). - num_epochs (int): Number of training epochs (default: 1). - max_train_step (int): Maximum training steps (default: 0, unlimited if 0). - max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0). - device (str): Device to train on (default: "qaic"). - num_workers_dataloader (int): Number of workers for data loading (default: 1). - lr (float): Learning rate (default: 3e-4). - weight_decay (float): Weight decay for optimizer (default: 0.0). - gamma (float): Learning rate decay factor (default: 0.85). - seed (int): Random seed for reproducibility (default: 42). - use_fp16 (bool): Use mixed precision training (default: True). - use_autocast (bool): Use autocast for mixed precision (default: True). - val_batch_size (int): Batch size for validation (default: 1). - dataset (str): Dataset name for training (default: "samsum_dataset"). - task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation") - peft_method (str): Parameter-efficient fine-tuning method (default: "lora"). - use_peft (bool): Whether to use PEFT (default: True). - from_peft_checkpoint (str): Path to PEFT checkpoint (default: ""). - output_dir (str): Directory to save outputs (default: "meta-llama-samsum"). - num_freeze_layers (int): Number of layers to freeze (default: 1). - one_qaic (bool): Use single QAIC device (default: False). - save_model (bool): Save the trained model (default: True). - save_metrics (bool): Save training metrics (default: True). - intermediate_step_save (int): Steps between intermediate saves (default: 1000). - batching_strategy (str): Batching strategy (default: "packing"). - enable_sorting_for_ddp (bool): Sort data for DDP (default: True). - convergence_counter (int): Steps to check convergence (default: 5). - convergence_loss (float): Loss threshold for convergence (default: 1e-4). - use_profiler (bool): Enable profiling (default: False). - enable_ddp (bool): Enable distributed data parallel (default: False). - dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo"). - grad_scaler (bool): Use gradient scaler (default: True). - dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_"). - opByOpVerifier (bool): Enable operation-by-operation verification (default: False). 
- """ - +class train_config: model_name: str = "meta-llama/Llama-3.2-1B" tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name run_validation: bool = True diff --git a/QEfficient/finetune/eval.py b/QEfficient/finetune/eval.py index 3fe6e0d81..918230554 100644 --- a/QEfficient/finetune/eval.py +++ b/QEfficient/finetune/eval.py @@ -11,6 +11,7 @@ import fire import numpy as np import torch +from configs.training import train_config as TRAIN_CONFIG from peft import AutoPeftModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer from utils.config_utils import ( @@ -24,8 +25,6 @@ ) from utils.train_utils import evaluation, print_model_size -from QEfficient.finetune.configs.training import TrainConfig - try: import torch_qaic # noqa: F401 @@ -40,7 +39,7 @@ def main(**kwargs): # update the configuration for the training process - train_config = TrainConfig() + train_config = TRAIN_CONFIG() update_config(train_config, **kwargs) # Set the seeds for reproducibility diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py index c5c7fe615..e979961d6 100644 --- a/QEfficient/finetune/utils/config_utils.py +++ b/QEfficient/finetune/utils/config_utils.py @@ -4,39 +4,27 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + import inspect -import json -import os from dataclasses import asdict -from typing import Any, Dict import torch.distributed as dist import torch.utils.data as data_utils -import yaml from peft import ( AdaptionPromptConfig, + LoraConfig, PrefixTuningConfig, ) -from peft import LoraConfig as PeftLoraConfig from transformers.data import DataCollatorForSeq2Seq import QEfficient.finetune.configs.dataset_config as datasets -from QEfficient.finetune.configs.peft_config import LoraConfig, PrefixConfig -from QEfficient.finetune.configs.training import TrainConfig +from QEfficient.finetune.configs.peft_config import lora_config, prefix_config +from QEfficient.finetune.configs.training import train_config from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC def update_config(config, **kwargs): - """Update the attributes of a config object based on provided keyword arguments. - - Args: - config: The configuration object (e.g., TrainConfig, LoraConfig) or a list/tuple of such objects. - **kwargs: Keyword arguments representing attributes to update. - - Raises: - ValueError: If an unknown parameter is provided and the config type doesn't support nested updates. - """ if isinstance(config, (tuple, list)): for c in config: update_config(c, **kwargs) @@ -45,73 +33,40 @@ def update_config(config, **kwargs): if hasattr(config, k): setattr(config, k, v) elif "." in k: - config_name, param_name = k.split(".", 1) - if type(config).__name__.lower() == config_name.lower(): + # allow --some_config.some_param=True + config_name, param_name = k.split(".") + if type(config).__name__ == config_name: if hasattr(config, param_name): setattr(config, param_name, v) else: - raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'") - else: - config_type = type(config).__name__ - # FIXME (Meet): Once logger is available put this in debug level. 
- print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'") + # In case of specialized config we can warn user + assert False, f"Warning: {config_name} does not accept parameter: {k}" + elif isinstance(config, train_config): + assert False, f"Warning: unknown parameter {k}" -def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any: - """Generate a PEFT-compatible configuration from a custom config based on peft_method. +def generate_peft_config(train_config, kwargs): + configs = (lora_config, prefix_config) + peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) + names = tuple(c.__name__.rstrip("_config") for c in configs) - Args: - train_config (TrainConfig): Training configuration with peft_method. - custom_config: Custom configuration object (e.g., LoraConfig). + if train_config.peft_method not in names: + raise RuntimeError(f"Peft config not found: {train_config.peft_method}") - Returns: - Any: A PEFT-specific configuration object (e.g., PeftLoraConfig). + config = configs[names.index(train_config.peft_method)]() - Raises: - RuntimeError: If the peft_method is not supported. - """ - if peft_config_file: - peft_config_data = load_config_file(peft_config_file) - validate_config(peft_config_data, config_type="lora") - peft_config = PeftLoraConfig(**peft_config_data) - else: - config_map = { - "lora": (LoraConfig, PeftLoraConfig), - "prefix": (PrefixConfig, PrefixTuningConfig), - "adaption_prompt": (None, AdaptionPromptConfig), - } - - if train_config.peft_method not in config_map: - raise RuntimeError(f"Peft config not found: {train_config.peft_method}") - - config_cls, peft_config_cls = config_map[train_config.peft_method] - if config_cls is None: - params = kwargs - else: - config = config_cls() - update_config(config, **kwargs) - params = asdict(config) + update_config(config, **kwargs) + params = asdict(config) + peft_config = peft_configs[names.index(train_config.peft_method)](**params) - peft_config = peft_config_cls(**params) return peft_config -def generate_dataset_config(dataset_name: str) -> Any: - """Generate a dataset configuration based on the specified dataset. - - Args: - dataset_name (str): Name of the dataset to be used for finetuning. - - Returns: - Any: A dataset configuration object. - - Raises: - AssertionError: If the dataset name is not recognized. - """ - supported_datasets = DATASET_PREPROC.keys() - assert dataset_name in supported_datasets, f"Given dataset '{dataset_name}' is not supported." - # FIXME (Meet): Replace below logic by creating using auto registry of datasets. - dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[dataset_name]() +def generate_dataset_config(train_config, kwargs): + names = tuple(DATASET_PREPROC.keys()) + assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" + dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() + update_config(dataset_config, **kwargs) return dataset_config @@ -143,84 +98,3 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode): kwargs["drop_last"] = True kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) return kwargs - - -def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None: - """Validate the provided YAML/JSON configuration for required fields and types. - - Args: - config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON. 
- config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora"). - - Raises: - ValueError: If required fields are missing or have incorrect types. - FileNotFoundError: If the config file path is invalid (handled upstream). - - Notes: - - Validates required fields for LoraConfig: r, lora_alpha, target_modules. - - Ensures types match expected values (int, float, list, etc.). - """ - if config_type.lower() != "lora": - raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.") - - required_fields = { - "r": int, - "lora_alpha": int, - "target_modules": list, - } - optional_fields = { - "bias": str, - "task_type": str, - "lora_dropout": float, - "inference_mode": bool, - } - - # Check for missing required fields - missing_fields = [field for field in required_fields if field not in config_data] - if missing_fields: - raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}") - - # Validate types of required fields - for field, expected_type in required_fields.items(): - if not isinstance(config_data[field], expected_type): - raise ValueError( - f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, " - f"got {type(config_data[field]).__name__}" - ) - - # Validate target_modules contains strings - if not all(isinstance(mod, str) for mod in config_data["target_modules"]): - raise ValueError("All elements in 'target_modules' must be strings") - - # Validate types of optional fields if present - for field, expected_type in optional_fields.items(): - if field in config_data and not isinstance(config_data[field], expected_type): - raise ValueError( - f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, " - f"got {type(config_data[field]).__name__}" - ) - - -def load_config_file(config_path: str) -> Dict[str, Any]: - """Load a configuration from a YAML or JSON file. - - Args: - config_path (str): Path to the YAML or JSON file. - - Returns: - Dict[str, Any]: The loaded configuration as a dictionary. - - Raises: - FileNotFoundError: If the file does not exist. - ValueError: If the file format is unsupported. - """ - if not os.path.exists(config_path): - raise FileNotFoundError(f"Config file not found: {config_path}") - - with open(config_path, "r") as f: - if config_path.endswith(".yaml") or config_path.endswith(".yml"): - return yaml.safe_load(f) - elif config_path.endswith(".json"): - return json.load(f) - else: - raise ValueError("Unsupported config file format. 
Use .yaml, .yml, or .json") diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 8693ae32d..2bc701008 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -18,7 +18,7 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from QEfficient.finetune.configs.training import TrainConfig +from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG try: import torch_qaic # noqa: F401 @@ -34,31 +34,34 @@ def train( model, - tokenizer, train_dataloader, eval_dataloader, + tokenizer, optimizer, lr_scheduler, - train_config: TrainConfig, + gradient_accumulation_steps, + train_config: TRAIN_CONFIG, + device, local_rank=None, + rank=None, ): """ Trains the model on the given dataloader Args: model: The model to be trained - tokenizer: tokenizer used in the eval for decoding the predicitons train_dataloader: The dataloader containing the training data - eval_dataloader: The dataloader containing the eval data optimizer: The optimizer used for training lr_scheduler: The learning rate scheduler - train_config: The training configuration + gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation + num_epochs: The number of epochs to train for local_rank: The rank of the current node in a distributed setting + train_config: The training configuration + eval_dataloader: The dataloader containing the eval data + tokenizer: tokenizer used in the eval for decoding the predicitons Returns: results dictionary containing average training and validation perplexity and loss """ - device = train_config.device - train_metric = [] train_loss = [] val_metric = [] @@ -458,7 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device): # Print evaluation metrics print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}") - return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric + return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 70601489d..58b837e9c 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -560,9 +560,16 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.vision_tower - def forward(self, pixel_values): + def forward(self, input_ids, pixel_values): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape image_features = self.model.get_image_features(pixel_values=pixel_values) - return image_features + selected = input_ids == self.model.config.image_token_index + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + return image_input_embeds class QEffGemma3DecoderWrapper(nn.Module): @@ -572,21 +579,14 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, index, past_key_values): - inputs_embeds = 
self.model.get_input_embeddings()(input_ids) - B, N, C = inputs_embeds.shape - selected = input_ids == self.model.config.image_token_index - indices1 = selected.to(torch.int64).cumsum(1) - 1 - indices1 = torch.where(indices1 != -1, indices1 + index, indices1) - indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] - image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) - inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + def forward(self, input_ids, vision_embeds, position_ids, past_key_values): + image_embeds = vision_embeds[:, : input_ids.shape[1], :] + inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True ) - index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return outputs.logits, vision_embeds, index, outputs.past_key_values + return outputs.logits, vision_embeds, outputs.past_key_values class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration): @@ -605,6 +605,11 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): + vision_seq_len = compiler_options.pop("vision_seq_len", None) + if vision_seq_len is None: + # TODO: Check properly for Gemma3, Not verified yet. + vision_seq_len = 512 # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size) + prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN if img_size is None and hasattr(self.config.vision_config, "image_size"): @@ -612,13 +617,12 @@ def get_specializations( elif img_size is None: img_size = 896 # FIXME based on gemma3 Image size logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) vision = [ { "batch_size": batch_size, "img_size": img_size, - "seq_len": prefill_seq_len, + "seq_len": vision_seq_len, "ctx_len": ctx_len, } ] @@ -628,14 +632,14 @@ def get_specializations( "seq_len": prefill_seq_len, "ctx_len": ctx_len, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "chunk_length": prefill_seq_len, }, { "batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "chunk_length": prefill_seq_len, }, ] @@ -654,8 +658,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} + vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} for i in range(self.language_model.config.num_hidden_layers): @@ -680,7 +685,6 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: lang_output_names.insert(1, "vision_embeds_RetainedState") - lang_output_names.insert(2, 
"index_output") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: @@ -694,13 +698,12 @@ def get_dummy_inputs(self, kv_offload: bool = False): else: img_size = 896 - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, # constants.INTERN_NUM_PATCHES, - mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, # constants.INTERN_FEATURE_SIZE, self.language_model.config.hidden_size, # 5120 ) inputs_shapes["position_ids"] = ( @@ -713,12 +716,12 @@ def get_dummy_inputs(self, kv_offload: bool = False): img_size, img_size, ) - inputs_shapes["index"] = (1, 1) # Define inputs vision_inputs = {} lang_inputs = {} vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) lang_inputs["position_ids"] = ( @@ -726,7 +729,7 @@ def get_dummy_inputs(self, kv_offload: bool = False): .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) - lang_inputs["index"] = torch.zeros((inputs_shapes["index"]), dtype=torch.int64) + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ebfd529cc..1a9610187 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -751,8 +751,8 @@ def kv_offload_generate( input_len = inputs["attention_mask"].sum(1, keepdims=True) input_ids_length = inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float - padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - + # padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1] if generation_len is None: generation_len = ctx_len - input_len.max() assert generation_len > 0, "generation length should be greater than zero" @@ -783,11 +783,13 @@ def kv_offload_generate( } vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_inputs["input_ids"] = inputs["input_ids"] vision_start = perf_counter() vision_outputs = vision_session.run(vision_inputs) vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["input_ids"] = inputs["input_ids"] lang_inputs["position_ids"] = np.where( lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 ) # Need to use -1 as position_ids for invalid tokens @@ -795,27 +797,25 @@ def kv_offload_generate( vision_session.deactivate() lang_session.activate() lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"] - lang_session.set_buffers(vision_outputs) + # lang_session.set_buffers(vision_outputs) prefill_start = perf_counter() # Run prefill - chunk_inputs = lang_inputs.copy() - chunk_inputs["index"] = np.array([[0]]) for i in range(num_chunks): + chunk_inputs = 
lang_inputs.copy() chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ :, i * prefill_seq_len : (i + 1) * prefill_seq_len ] + chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][ + :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ] outputs = lang_session.run(chunk_inputs) - chunk_inputs["index"] = outputs["index_output"] prefill_time = perf_counter() - prefill_start + vision_end - vision_start + lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len] # Skip inputs/outputs again lang_session.skip_buffers( - [ - x - for x in lang_session.input_names + lang_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] + [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] ) # Get first token @@ -1643,11 +1643,6 @@ def compile( **compiler_options, ) - if compiler_options.get("io_encrypt", None): - logger.warning( - "Compilation for IO-Encrypt has been successfully completed. However, Efficient-Transformers do not support IO-Encrypt execution. Please run the execution separately with QPC compiled without io-encrypt." - ) - return qpc_path # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 564bdd94d..b6af66be5 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -521,57 +521,27 @@ def __repr__(self): def dump_qconfig(func): def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) - try: - create_and_dump_qconfigs( - self.qpc_path, - self.onnx_path, - self.get_model_config, - [cls.__name__ for cls in self._pytorch_transforms], - [cls.__name__ for cls in self._onnx_transforms], - kwargs.get("specializations"), - kwargs.get("mdp_ts_num_devices", 1), - kwargs.get("num_speculative_tokens"), - **{ - k: v - for k, v in kwargs.items() - if k - not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"] - }, - ) - except Exception as e: - print(f"An unexpected error occurred while dumping the qconfig: {e}") + create_and_dump_qconfigs( + self.qpc_path, + self.onnx_path, + self.get_model_config, + [cls.__name__ for cls in self._pytorch_transforms], + [cls.__name__ for cls in self._onnx_transforms], + kwargs.get("specializations"), + kwargs.get("mdp_ts_num_devices", 1), + kwargs.get("num_speculative_tokens"), + **{ + k: v + for k, v in kwargs.items() + if k + not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"] + }, + ) return result return wrapper -def get_qaic_sdk_version(qaic_sdk_xml_path: str) -> Optional[str]: - """ - Extracts the QAIC SDK version from the given SDK XML file. - - Args: - qaic_sdk_xml_path (str): Path to the SDK XML file. - Returns: - The SDK version as a string if found, otherwise None. 
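The prefill loop above splits the padded prompt into prefill_seq_len-sized chunks (ceil division without floats) and feeds the language-model session one chunk at a time. A toy NumPy sketch of just the slicing logic; the lang_session.run call is the hardware-backed step and is only hinted at (not from this patch):

import numpy as np

prefill_seq_len = 8
prompt_len = 30
num_chunks = -(prompt_len // -prefill_seq_len)   # ceil divide without float -> 4
padded_len = num_chunks * prefill_seq_len        # 32, a multiple of prefill_seq_len

input_ids = np.zeros((1, padded_len), dtype=np.int64)
attention_mask = np.zeros((1, padded_len), dtype=np.int64)
attention_mask[:, :prompt_len] = 1
# -1 position ids mark the padded (invalid) tail, mirroring the generate helpers above.
position_ids = np.where(attention_mask, np.arange(padded_len), -1)

for i in range(num_chunks):
    sl = slice(i * prefill_seq_len, (i + 1) * prefill_seq_len)
    chunk_inputs = {"input_ids": input_ids[:, sl], "position_ids": position_ids[:, sl]}
    # outputs = lang_session.run(chunk_inputs)    # QAIC session call in the real flow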
- """ - qaic_sdk_version = None - - # Check and extract version from the given SDK XML file - if os.path.exists(qaic_sdk_xml_path): - try: - tree = ET.parse(qaic_sdk_xml_path) - root = tree.getroot() - base_version_element = root.find(".//base_version") - if base_version_element is not None: - qaic_sdk_version = base_version_element.text - except ET.ParseError as e: - print(f"Error parsing XML file {qaic_sdk_xml_path}: {e}") - except Exception as e: - print(f"An unexpected error occurred while processing {qaic_sdk_xml_path}: {e}") - - return qaic_sdk_version - - def create_and_dump_qconfigs( qpc_path, onnx_path, @@ -588,12 +558,29 @@ def create_and_dump_qconfigs( Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and many other compilation options. """ - enable_qnn = compiler_options.get("enable_qnn", False) - qnn_config_path = compiler_options.get("qnn_config", None) + qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None + enable_qnn = True if "qnn_config" in compiler_options else None + qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json") onnx_path = str(onnx_path) specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json")) compile_dir = str(os.path.dirname(qpc_path)) + qnn_config_path = ( + (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None + ) + + # Extract QAIC SDK Apps Version from SDK XML file + tree = ET.parse(Constants.SDK_APPS_XML) + root = tree.getroot() + qaic_version = root.find(".//base_version").text + + # Extract QNN SDK details from YAML file if the environment variable is set + qnn_sdk_details = None + qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) + if enable_qnn and qnn_sdk_path: + qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML) + with open(qnn_sdk_yaml_path, "r") as file: + qnn_sdk_details = yaml.safe_load(file) # Ensure all objects in the configs dictionary are JSON serializable def make_serializable(obj): @@ -615,38 +602,29 @@ def make_serializable(obj): "onnx_transforms": make_serializable(onnx_transforms), "onnx_path": onnx_path, }, - "compiler_config": { - "enable_qnn": enable_qnn, - "compile_dir": compile_dir, - "specializations_file_path": specializations_file_path, - "specializations": make_serializable(specializations), - "mdp_ts_num_devices": mdp_ts_num_devices, - "num_speculative_tokens": num_speculative_tokens, - **compiler_options, - }, - "aic_sdk_config": { - "qaic_apps_version": get_qaic_sdk_version(Constants.SDK_APPS_XML), - "qaic_platform_version": get_qaic_sdk_version(Constants.SDK_PLATFORM_XML), - }, }, } + aic_compiler_config = { + "apps_sdk_version": qaic_version, + "compile_dir": compile_dir, + "specializations_file_path": specializations_file_path, + "specializations": make_serializable(specializations), + "mdp_ts_num_devices": mdp_ts_num_devices, + "num_speculative_tokens": num_speculative_tokens, + **compiler_options, + } + qnn_config = { + "enable_qnn": enable_qnn, + "qnn_config_path": qnn_config_path, + } + # Put AIC or qnn details. if enable_qnn: - qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME) - if not qnn_sdk_path: - raise EnvironmentError( - f"QNN_SDK_PATH {qnn_sdk_path} is not set. 
Please set {QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME}" - ) - qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML) - qnn_sdk_details = load_yaml( - qnn_sdk_yaml_path - ) # Extract QNN SDK details from YAML file if the environment variable is set - qnn_config = { - "qnn_config_path": qnn_config_path, - } qconfigs["qpc_config"]["qnn_config"] = qnn_config if qnn_sdk_details: qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details) + else: + qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config create_json(qconfig_file_path, qconfigs) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index c8f74907a..b1ff9701e 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -97,10 +97,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 - SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. - SDK_PLATFORM_XML = ( - "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. - ) + SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK version. @dataclass diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 7036d6f6d..fcd2fece5 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -171,4 +171,4 @@ pipeline { deleteDir() } } -} \ No newline at end of file +} diff --git a/scripts/finetune/run_ft_model.py b/scripts/finetune/run_ft_model.py index ef014923b..5e88db641 100644 --- a/scripts/finetune/run_ft_model.py +++ b/scripts/finetune/run_ft_model.py @@ -12,7 +12,7 @@ from peft import AutoPeftModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer -from QEfficient.finetune.configs.training import TrainConfig +from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG # Suppress all warnings warnings.filterwarnings("ignore") @@ -25,7 +25,7 @@ print(f"Warning: {e}. Moving ahead without these qaic modules.") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -train_config = TrainConfig() +train_config = TRAIN_CONFIG() model = AutoModelForCausalLM.from_pretrained( train_config.model_name, use_cache=False, diff --git a/tests/finetune/test_finetune.py b/tests/finetune/test_finetune.py index fb4a84dc0..45330cad6 100644 --- a/tests/finetune/test_finetune.py +++ b/tests/finetune/test_finetune.py @@ -8,7 +8,6 @@ import os import shutil -import numpy as np import pytest import torch.optim as optim from torch.utils.data import DataLoader @@ -23,25 +22,12 @@ def clean_up(path): shutil.rmtree(path) -configs = [ - pytest.param( - "meta-llama/Llama-3.2-1B", # model_name - 10, # max_eval_step - 20, # max_train_step - 1, # intermediate_step_save - None, # context_length - True, # run_validation - True, # use_peft - "qaic", # device - id="llama_config", # config name - ) -] +configs = [pytest.param("meta-llama/Llama-3.2-1B", 1, 1, 1, None, True, True, "cpu", id="llama_config")] -@pytest.mark.skip(reason="Currently CI is broken. 
Once it is fixed we will enable this test.") -@pytest.mark.cli +# TODO:enable this once docker is available @pytest.mark.on_qaic -@pytest.mark.finetune +@pytest.mark.skip(reason="eager docker not available in sdk") @pytest.mark.parametrize( "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device", configs, @@ -57,7 +43,7 @@ def test_finetune( device, mocker, ): - train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig") + train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TRAIN_CONFIG") generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config") generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config") get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs") @@ -79,28 +65,23 @@ def test_finetune( "device": device, } - results = finetune(**kwargs) - assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching." - assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching." - assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching." - assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching." - assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds." + finetune(**kwargs) train_config_spy.assert_called_once() generate_dataset_config_spy.assert_called_once() generate_peft_config_spy.assert_called_once() + update_config_spy.assert_called_once() get_custom_data_collator_spy.assert_called_once() get_longest_seq_length_spy.assert_called_once() print_model_size_spy.assert_called_once() train_spy.assert_called_once() - assert update_config_spy.call_count == 2 assert get_dataloader_kwargs_spy.call_count == 2 assert get_preprocessed_dataset_spy.call_count == 2 args, kwargs = train_spy.call_args - train_dataloader = args[2] - eval_dataloader = args[3] + train_dataloader = args[1] + eval_dataloader = args[2] optimizer = args[4] batch = next(iter(train_dataloader)) @@ -116,19 +97,12 @@ def test_finetune( else: assert eval_dataloader is None - args, kwargs = update_config_spy.call_args_list[0] + args, kwargs = update_config_spy.call_args train_config = args[0] - assert max_train_step >= train_config.gradient_accumulation_steps, ( - "Total training step should be more than " - f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps." 
-    )
-    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
+    saved_file = os.path.join(train_config.output_dir, "adapter_model.safetensors")
     assert os.path.isfile(saved_file)
     clean_up(train_config.output_dir)
     clean_up("runs")
     clean_up(train_config.dump_root_dir)
-
-
-# TODO (Meet): Add seperate tests for BERT FT and LLama FT
diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py
index 71b4e01cd..c80fe5969 100644
--- a/tests/transformers/spd/test_pld_inference.py
+++ b/tests/transformers/spd/test_pld_inference.py
@@ -262,7 +262,7 @@ def test_pld_spec_decode_inference(
         num_speculative_tokens=num_speculative_tokens,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path)
+    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
     draft_model_session = None
 
     # skip inputs/outputs buffers
@@ -453,7 +453,7 @@ def test_pld_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR)
+    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR, device_group)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py
index e87c51d5f..6f6bdb268 100644
--- a/tests/transformers/spd/test_spd_inference.py
+++ b/tests/transformers/spd/test_spd_inference.py
@@ -157,8 +157,8 @@ def test_spec_decode_inference(
         full_batch_size=full_batch_size,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path)
+    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
+    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
 
     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
@@ -341,7 +341,7 @@ def test_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR)
+    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size

From f434ea32bf060bae5ae1dda48964e5010ac2bc24 Mon Sep 17 00:00:00 2001
From: Mohit Soni
Date: Wed, 14 May 2025 15:20:43 +0530
Subject: [PATCH 7/7] Updating Wrappers for Merging and Chunking in DecoderWrapper (#404)

Signed-off-by: Mohit Soni
Signed-off-by: Mohit Soni
---
 .../models/gemma3/modeling_gemma3.py            | 51 +++++++++----------
 .../transformers/models/modeling_auto.py        | 22 ++++----
 2 files changed, 35 insertions(+), 38 deletions(-)

diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index 58b837e9c..70601489d 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -560,16 +560,9 @@ def __init__(self, model):
         self.model = model
         self.model.vision_model = self.model.vision_tower
 
-    def forward(self, input_ids, pixel_values):
-        inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        B, N, C = inputs_embeds.shape
+    def forward(self, pixel_values):
         image_features = self.model.get_image_features(pixel_values=pixel_values)
-        selected = input_ids == self.model.config.image_token_index
-        indices1 = selected.to(torch.int64).cumsum(1) - 1
-        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1]
-        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
-        return image_input_embeds
+        return image_features
 
 
 class QEffGemma3DecoderWrapper(nn.Module):
@@ -579,14 +572,21 @@ def __init__(self, model):
         self.language_model = self.model.language_model
         self.config = self.model.config
 
-    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
-        image_embeds = vision_embeds[:, : input_ids.shape[1], :]
-        inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
-        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
+    def forward(self, input_ids, vision_embeds, position_ids, index, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.model.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + index, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vision_embeds, outputs.past_key_values
+        index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, index, outputs.past_key_values
 
 
 class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):
@@ -605,11 +605,6 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-        vision_seq_len = compiler_options.pop("vision_seq_len", None)
-        if vision_seq_len is None:
-            # TODO: Check properly for Gemma3, Not verified yet.
-            vision_seq_len = 512  # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size)
-
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
@@ -617,12 +612,13 @@ def get_specializations(
         elif img_size is None:
             img_size = 896  # FIXME based on gemma3 Image size
             logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
 
         vision = [
             {
                 "batch_size": batch_size,
                 "img_size": img_size,
-                "seq_len": vision_seq_len,
+                "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
             }
         ]
@@ -632,14 +628,14 @@ def get_specializations(
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
         ]
 
@@ -658,9 +654,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"}
+        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"}
         vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}
-        vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -685,6 +680,7 @@ def get_output_names(self, kv_offload: bool = False):
         output_names = {}
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "index_output")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -698,12 +694,13 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         else:
             img_size = 896
 
+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
             1,  # constants.INTERN_NUM_PATCHES,
-            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,  # constants.INTERN_FEATURE_SIZE,
+            mm_tokens_per_image,  # constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,  # 5120
         )
         inputs_shapes["position_ids"] = (
@@ -716,12 +713,12 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             img_size,
             img_size,
         )
+        inputs_shapes["index"] = (1, 1)
 
         # Define inputs
         vision_inputs = {}
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
-        vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
@@ -729,7 +726,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
-
+        lang_inputs["index"] = torch.zeros((inputs_shapes["index"]), dtype=torch.int64)
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
             config=self.language_model.config,
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 1a9610187..ac40352d0 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -751,8 +751,8 @@ def kv_offload_generate(
         input_len = inputs["attention_mask"].sum(1, keepdims=True)
         input_ids_length = inputs["input_ids"].shape[1]
         num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
-        # padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
-        padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1]
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
         if generation_len is None:
             generation_len = ctx_len - input_len.max()
         assert generation_len > 0, "generation length should be greater than zero"
@@ -783,13 +783,11 @@ def kv_offload_generate(
         }
 
         vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")
-        vision_inputs["input_ids"] = inputs["input_ids"]
         vision_start = perf_counter()
         vision_outputs = vision_session.run(vision_inputs)
         vision_end = perf_counter()
 
         lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
-        lang_inputs["input_ids"] = inputs["input_ids"]
         lang_inputs["position_ids"] = np.where(
             lang_inputs.pop("attention_mask"), np.arange(padded_len), -1
         )  # Need to use -1 as position_ids for invalid tokens
@@ -797,25 +795,27 @@ def kv_offload_generate(
         vision_session.deactivate()
         lang_session.activate()
         lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
-        # lang_session.set_buffers(vision_outputs)
+        lang_session.set_buffers(vision_outputs)
         prefill_start = perf_counter()
 
         # Run prefill
+        chunk_inputs = lang_inputs.copy()
+        chunk_inputs["index"] = np.array([[0]])
        for i in range(num_chunks):
-            chunk_inputs = lang_inputs.copy()
             chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             chunk_inputs["position_ids"] = lang_inputs["position_ids"][
                 :, i * prefill_seq_len : (i + 1) * prefill_seq_len
             ]
-            chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][
-                :, i * prefill_seq_len : (i + 1) * prefill_seq_len
-            ]
             outputs = lang_session.run(chunk_inputs)
+            chunk_inputs["index"] = outputs["index_output"]
 
         prefill_time = perf_counter() - prefill_start + vision_end - vision_start
-        lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len]
         # Skip inputs/outputs again
         lang_session.skip_buffers(
-            [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
+            [
+                x
+                for x in lang_session.input_names + lang_session.output_names
+                if x.startswith("past_") or x.endswith("_RetainedState")
+            ]
         )
         # Get first token
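
A minimal sketch (outside the diff, under stated assumptions) of how the chunked prefill above is intended to be driven: the vision QPC runs once, and the language QPC consumes the prompt in prefill_seq_len chunks while the `index` input carries the image-feature offset returned as `index_output` by the previous chunk. It assumes `vision_session`/`lang_session` are already-loaded QAICInferenceSession objects and that `vision_inputs`, `lang_inputs`, `num_chunks`, and `prefill_seq_len` are prepared as in `kv_offload_generate`.

    import numpy as np

    # Encode the image once; vision_embeds has shape (1, mm_tokens_per_image, hidden_size)
    vision_outputs = vision_session.run(vision_inputs)
    lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]

    chunk_inputs = lang_inputs.copy()
    chunk_inputs["index"] = np.array([[0]])  # offset into vision_embeds for the first chunk
    for i in range(num_chunks):
        sl = slice(i * prefill_seq_len, (i + 1) * prefill_seq_len)
        chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, sl]
        chunk_inputs["position_ids"] = lang_inputs["position_ids"][:, sl]
        outputs = lang_session.run(chunk_inputs)
        chunk_inputs["index"] = outputs["index_output"]  # carry the offset into the next chunk

Carrying the offset this way lets image tokens that straddle a chunk boundary index into the full vision_embeds tensor, instead of slicing vision_embeds per chunk as the previous wrapper did.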