From acdaf5f0f16b07d0d60bb4bb27e924370d06ce45 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Fri, 21 Mar 2025 08:05:19 +0000 Subject: [PATCH 01/16] Llama unabled for this model Signed-off-by: Amit Raj --- .../models/llama/modeling_llama.py | 334 +++++++++++++++++- 1 file changed, 327 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index dae783361..b7c773307 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,12 @@ # # ----------------------------------------------------------------------------- +<<<<<<< HEAD from typing import Callable, List, Optional, Tuple, Union +======= +import math +from typing import Callable, Dict, List, Optional, Tuple, Union +>>>>>>> bb7e949 (Llama unabled for this model) import torch from torch import nn @@ -28,7 +33,6 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask - class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -42,7 +46,7 @@ def __init__(self, config: LlamaConfig, device=None): self._set_cos_sin_cache( seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() ) - + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) @@ -94,7 +98,19 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) +<<<<<<< HEAD def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -107,6 +123,8 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) +======= +>>>>>>> bb7e949 (Llama unabled for this model) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) @@ -116,10 +134,21 @@ def eager_attention_forward( attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights +<<<<<<< HEAD class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" +======= +class QEffLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.config = config + self.__qeff_init__() +>>>>>>> bb7e949 (Llama unabled for this model) def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) @@ -127,18 +156,21 @@ def __qeff_init__(self): def forward( 
self, hidden_states: torch.Tensor, +<<<<<<< HEAD +======= + position_embeddings: Tuple[torch.Tensor, torch.Tensor], +>>>>>>> bb7e949 (Llama unabled for this model) attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) +<<<<<<< HEAD query_states = self.q_proj(hidden_states, **kwargs) key_states = self.k_proj(hidden_states, **kwargs) value_states = self.v_proj(hidden_states, **kwargs) @@ -151,11 +183,27 @@ def forward( kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) +======= + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) +>>>>>>> bb7e949 (Llama unabled for this model) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -169,6 +217,7 @@ def forward( scaling=self.scaling, **kwargs, ) +<<<<<<< HEAD attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -324,13 +373,19 @@ def forward( return output if return_dict else output.to_tuple() +======= + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + +>>>>>>> bb7e949 (Llama unabled for this model) class QEffLlamaForCausalLM(LlamaForCausalLM): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - add new args cache idx for the kv retention """ - def forward( self, input_ids: torch.LongTensor = None, @@ -348,6 +403,10 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: +<<<<<<< HEAD +======= + +>>>>>>> bb7e949 (Llama unabled for this model) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states @@ -369,13 +428,31 @@ def forward( cache_position=cache_position, **kwargs, ) - + # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] +<<<<<<< HEAD + logits = self.lm_head(hidden_states) + logits = logits.float() +======= + # # hidden_states = outputs[0] + # # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + # logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self.lm_head(hidden_states) logits = logits.float() + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output +>>>>>>> bb7e949 (Llama unabled for this model) return CausalLMOutputWithPast( loss=None, @@ -384,3 +461,246 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) +<<<<<<< HEAD +======= + +class QEffLlamaDecoderLayer(LlamaDecoderLayer): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - add new args batch idx for the CB models + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffLlamaModel(LlamaModel): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: 
Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden 
states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values=past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + else: + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask +>>>>>>> bb7e949 (Llama unabled for this model) From 2b275bae2d0a684bbbee0a7c930141f03246728b Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 24 Mar 2025 05:36:46 +0000 Subject: [PATCH 02/16] tf version 4.50 enanbled for llama, mpt and gemma Signed-off-by: Amit Raj --- .../models/gemma/modeling_gemma.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 44142f4e0..369c3d5c8 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -28,6 +28,27 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): """ @@ -86,8 +107,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) @@ -298,6 +319,9 @@ def forward( # embed positions hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 From d3cce265ba0441d7d08da3787c366e78d2b78ee2 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 24 Mar 2025 10:02:07 +0000 Subject: [PATCH 03/16] 'Enabled gemma2, codegen, falcon Signed-off-by: Amit Raj --- .../models/falcon/modeling_falcon.py | 4 +- .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 7 +- .../models/llama/modeling_llama.py | 327 +----------------- .../transformers/models/mpt/modeling_mpt.py | 2 + 5 files changed, 20 insertions(+), 324 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 7f6862ffb..e4ef30b86 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -89,8 +89,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 369c3d5c8..3d5680655 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -28,6 +28,7 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -50,6 +51,7 @@ def eager_attention_forward( return attn_output, attn_weights + class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): """ Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py @@ -319,7 +321,7 @@ def forward( # embed positions hidden_states = inputs_embeds - + # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index e73d759d8..40b45773e 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -89,8 +89,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) @@ -318,6 +318,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index b7c773307..5ce7a3173 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,12 +5,7 @@ # # ----------------------------------------------------------------------------- -<<<<<<< HEAD from typing import Callable, List, Optional, Tuple, Union -======= -import math -from typing import Callable, Dict, List, Optional, Tuple, Union ->>>>>>> bb7e949 (Llama unabled for this model) import torch from torch import nn @@ -33,7 +28,8 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): + +class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding, nn.Module): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -46,7 +42,7 @@ def __init__(self, config: LlamaConfig, device=None): self._set_cos_sin_cache( seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() ) - + def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) @@ -89,8 +85,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) @@ -110,7 +106,6 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) -<<<<<<< HEAD def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -123,8 +118,6 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) -======= ->>>>>>> bb7e949 (Llama unabled for this model) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) @@ -134,21 +127,10 @@ def eager_attention_forward( attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -<<<<<<< HEAD class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" -======= -class QEffLlamaAttention(LlamaAttention): - """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - # Define the general __qeff_init__() for any changes in the init calls - # Set the init in the module mapping pytorch transforms - self.config = config - self.__qeff_init__() ->>>>>>> bb7e949 (Llama unabled for this model) def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) @@ -156,10 +138,6 @@ def __qeff_init__(self): def forward( self, hidden_states: torch.Tensor, -<<<<<<< HEAD -======= - position_embeddings: Tuple[torch.Tensor, torch.Tensor], ->>>>>>> bb7e949 (Llama unabled for this model) attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -170,7 +148,6 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) -<<<<<<< HEAD query_states = self.q_proj(hidden_states, **kwargs) key_states = self.k_proj(hidden_states, **kwargs) value_states = self.v_proj(hidden_states, **kwargs) @@ -183,27 +160,11 @@ def forward( kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) -======= - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) ->>>>>>> bb7e949 (Llama unabled for this model) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -217,7 +178,6 @@ def forward( scaling=self.scaling, **kwargs, ) -<<<<<<< HEAD attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -373,19 +333,13 @@ def forward( return output if return_dict else output.to_tuple() -======= - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value - ->>>>>>> bb7e949 (Llama unabled for this model) class QEffLlamaForCausalLM(LlamaForCausalLM): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: - add new args cache idx for the kv retention """ + def forward( self, input_ids: torch.LongTensor = None, @@ -403,10 +357,6 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: -<<<<<<< HEAD -======= - ->>>>>>> bb7e949 (Llama unabled for this model) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -428,31 +378,13 @@ def forward( cache_position=cache_position, **kwargs, ) - + # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] -<<<<<<< HEAD - logits = self.lm_head(hidden_states) - logits = logits.float() -======= - # # hidden_states = outputs[0] - # # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - # logits = self.lm_head(hidden_states[:, slice_indices, :]) - logits = self.lm_head(hidden_states) logits = logits.float() - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output ->>>>>>> bb7e949 (Llama unabled for this model) return CausalLMOutputWithPast( loss=None, @@ -461,246 +393,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) -<<<<<<< HEAD -======= - -class QEffLlamaDecoderLayer(LlamaDecoderLayer): - """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - The only differences are: - - add new args batch idx for the CB models - """ - - def forward( - self, - hidden_states: 
torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class QEffLlamaModel(LlamaModel): - """ - Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - The only differences are: - - add new args cache idx for the kv retention - """ - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
- ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # if use_cache and past_key_values is None: - # past_key_values = DynamicCache() - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions - ) - - # embed positions - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if batch_index is not None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if return_legacy_cache: - past_key_values=past_key_values.to_legacy_cache() - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - 
target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens - - if attention_mask is not None and attention_mask.dim() == 4: - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - if attention_mask.max() != 0: - raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - else: - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask ->>>>>>> bb7e949 (Llama unabled for this model) diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 0d0fd9857..82c96e2dd 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -12,6 +12,8 @@ import torch import torch.utils.checkpoint from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, From c8c972c664cd7ff1e387b299b5c8657f883bbaa1 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 25 Mar 2025 09:49:53 +0000 Subject: [PATCH 04/16] Enabled gpt_bigcode, mistral, qwen, llava, internvl Signed-off-by: Amit Raj --- .../models/mistral/modeling_mistral.py | 9 ++++--- .../transformers/models/mpt/modeling_mpt.py | 2 -- .../models/qwen2/modeling_qwen2.py | 27 +++++++++++++++++-- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 183b07b3a..6d075c442 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -17,6 +17,7 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralConfig, @@ -90,8 +91,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, 
unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) @@ -140,8 +141,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -201,6 +200,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -231,6 +231,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 82c96e2dd..0d0fd9857 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -12,8 +12,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 0ea22cead..20adfbf85 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -32,6 +32,29 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding): """ @@ -101,8 +124,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) From 52f8b92b9613ee175679951a65e887522a06a48e Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Fri, 28 Mar 2025 05:25:15 +0000 Subject: [PATCH 05/16] Code cleaning and formating Signed-off-by: Amit Raj --- .../models/gemma2/modeling_gemma2.py | 38 +++++++++++++++++++ .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 1 - 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 40b45773e..812f8b8a5 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -40,7 +40,11 @@ class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding): """ def __init__(self, config: Gemma2Config, device=None): +<<<<<<< HEAD super().__init__(config=config) +======= + Gemma2RotaryEmbedding.__init__(self, config=config) +>>>>>>> 6ba4c76 (Code cleaning and formating) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( @@ -132,6 +136,16 @@ class QEffGemma2Attention(Gemma2Attention): - add new args cache idx for the kv retention """ +<<<<<<< HEAD +======= + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.config = config + self.__qeff_init__() + +>>>>>>> 6ba4c76 (Code cleaning and formating) def __qeff_init__(self): self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) @@ -344,6 +358,10 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, +<<<<<<< HEAD +======= + position_embeddings=position_embeddings, +>>>>>>> 6ba4c76 (Code cleaning and formating) **kwargs, ) @@ -369,6 +387,26 @@ def forward( ) return output if return_dict else output.to_tuple() +<<<<<<< HEAD +======= + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + return causal_mask + +>>>>>>> 6ba4c76 (Code cleaning and formating) class QEffGemma2ForCausalLM(Gemma2ForCausalLM, GenerationMixin): """ diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 5ce7a3173..d20fcb8d5 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -29,7 +29,7 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding, nn.Module): +class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 6d075c442..046ec561e 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -17,7 +17,6 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralConfig, From 42cfcbd0cde7922e5920d66048ce03d59e0f7369 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Fri, 28 Mar 2025 11:21:14 +0000 Subject: [PATCH 06/16] llama clean Signed-off-by: Amit Raj --- .../transformers/models/llama/modeling_llama.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index d20fcb8d5..2afd459b6 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -21,8 +21,8 @@ LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding, + apply_rotary_pos_emb, repeat_kv, - rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -94,17 +94,6 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) def eager_attention_forward( module: nn.Module, @@ -159,7 +148,7 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = position_embeddings query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From e7a5129efafad25a854755098853e78dac25a721 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 1 Apr 2025 04:58:01 +0000 Subject: [PATCH 07/16] falcon, gemma, gemma2, mistral, qwen clean up Signed-off-by: Amit Raj --- .../models/falcon/modeling_falcon.py | 32 ------------------ .../models/gemma/modeling_gemma.py | 1 + .../models/gemma2/modeling_gemma2.py | 33 +------------------ .../models/llama/modeling_llama.py | 13 +++++++- .../models/mistral/modeling_mistral.py | 32 +----------------- .../models/qwen2/modeling_qwen2.py | 26 ++------------- 6 files changed, 17 insertions(+), 120 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index e4ef30b86..b237e46c9 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -68,37 +68,6 @@ def forward(self, x, seq_len=None): ) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. 
- cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Apply rotation - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - # Cast back to original dtype - return q_embed.to(q.dtype), k_embed.to(k.dtype) - - class QEffFalconAttention(FalconAttention): """ Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py @@ -342,7 +311,6 @@ def forward( attentions=all_self_attentions, ) - class QEffFalconForCausalLM(FalconForCausalLM): """ Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 3d5680655..a069e7a5b 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -23,6 +23,7 @@ GemmaRotaryEmbedding, repeat_kv, rotate_half, + apply_rotary_pos_emb ) from QEfficient.transformers.cache_utils import QEffDynamicCache diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 812f8b8a5..14b52df2f 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -24,6 +24,7 @@ Gemma2RotaryEmbedding, repeat_kv, rotate_half, + apply_rotary_pos_emb, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -71,38 +72,6 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len].to(dtype=x.dtype), ) - -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Apply rotation - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - # Cast back to original dtype - return q_embed.to(q.dtype), k_embed.to(k.dtype) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 2afd459b6..1ec550831 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -94,6 +94,17 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) def eager_attention_forward( module: nn.Module, @@ -148,7 +159,7 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 046ec561e..f8f1c8415 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -26,6 +26,7 @@ MistralRotaryEmbedding, logger, repeat_kv, + apply_rotary_pos_emb, rotate_half, ) @@ -69,37 +70,6 @@ def forward(self, x, seq_len=None): ) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Apply rotation - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - # Cast back to original dtype - return q_embed.to(q.dtype), k_embed.to(k.dtype) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 20adfbf85..88617fcb0 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -26,6 +26,7 @@ Qwen2RotaryEmbedding, repeat_kv, rotate_half, + apply_rotary_pos_emb, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -133,29 +134,6 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class QEffQwen2Attention(Qwen2Attention): """ Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py @@ -186,7 +164,7 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = position_embeddings query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From e530d851934c888cdc3d87813259681a14cf2717 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 1 Apr 2025 04:58:55 +0000 Subject: [PATCH 08/16] Ruff check and format Signed-off-by: Amit Raj --- QEfficient/transformers/models/falcon/modeling_falcon.py | 2 +- QEfficient/transformers/models/gemma/modeling_gemma.py | 3 +-- QEfficient/transformers/models/gemma2/modeling_gemma2.py | 4 ++-- QEfficient/transformers/models/mistral/modeling_mistral.py | 3 +-- QEfficient/transformers/models/qwen2/modeling_qwen2.py | 6 
++++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index b237e46c9..b6a437007 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -26,7 +26,6 @@ FalconModel, FalconRotaryEmbedding, dropout_add, - rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -311,6 +310,7 @@ def forward( attentions=all_self_attentions, ) + class QEffFalconForCausalLM(FalconForCausalLM): """ Copied from FalconForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index a069e7a5b..7fc59079a 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -21,9 +21,8 @@ GemmaForCausalLM, GemmaModel, GemmaRotaryEmbedding, + apply_rotary_pos_emb, repeat_kv, - rotate_half, - apply_rotary_pos_emb ) from QEfficient.transformers.cache_utils import QEffDynamicCache diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 14b52df2f..4a8fbeffa 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -22,9 +22,8 @@ Gemma2ForCausalLM, Gemma2Model, Gemma2RotaryEmbedding, - repeat_kv, - rotate_half, apply_rotary_pos_emb, + repeat_kv, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -72,6 +71,7 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len].to(dtype=x.dtype), ) + def eager_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index f8f1c8415..9d0d89c20 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,10 +24,9 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, + apply_rotary_pos_emb, logger, repeat_kv, - apply_rotary_pos_emb, - rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 88617fcb0..a07aabf87 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -24,9 +24,8 @@ Qwen2ForCausalLM, Qwen2Model, Qwen2RotaryEmbedding, - repeat_kv, - rotate_half, apply_rotary_pos_emb, + repeat_kv, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -92,6 +91,7 @@ def forward(self, x, seq_len=None): ) +<<<<<<< HEAD def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
@@ -134,6 +134,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) +======= +>>>>>>> d0f7ffd (Ruff check and format) class QEffQwen2Attention(Qwen2Attention): """ Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py From e8eb268241d1fbff45adc32f85aaccf2320c0108 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Wed, 2 Apr 2025 04:44:09 +0000 Subject: [PATCH 09/16] Minor fixes Signed-off-by: Amit Raj --- .../models/falcon/modeling_falcon.py | 32 ++++++++++ .../models/gemma/modeling_gemma.py | 28 +-------- .../models/gemma2/modeling_gemma2.py | 36 +++++++++-- .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 33 +++++++++- .../models/qwen2/modeling_qwen2.py | 63 ++++++++++++------- 6 files changed, 137 insertions(+), 57 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index b6a437007..7f6862ffb 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -26,6 +26,7 @@ FalconModel, FalconRotaryEmbedding, dropout_add, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -67,6 +68,37 @@ def forward(self, x, seq_len=None): ) +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + class QEffFalconAttention(FalconAttention): """ Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 7fc59079a..5282ae484 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -21,37 +21,14 @@ GemmaForCausalLM, GemmaModel, GemmaRotaryEmbedding, - apply_rotary_pos_emb, repeat_kv, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): """ Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py @@ -322,9 +299,6 @@ def forward( # embed positions hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 4a8fbeffa..b7890b24e 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -22,8 +22,8 @@ Gemma2ForCausalLM, Gemma2Model, Gemma2RotaryEmbedding, - apply_rotary_pos_emb, repeat_kv, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -72,6 +72,37 @@ def forward(self, x, seq_len=None): ) +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -301,9 +332,6 @@ def forward( # embed positions hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 1ec550831..d20fcb8d5 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -21,8 +21,8 @@ LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding, - apply_rotary_pos_emb, repeat_kv, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 9d0d89c20..520488e1d 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,9 +24,9 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - apply_rotary_pos_emb, logger, repeat_kv, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -69,6 +69,37 @@ def forward(self, x, seq_len=None): ) +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + def eager_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index a07aabf87..ef2300ae6 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -24,37 +24,14 @@ Qwen2ForCausalLM, Qwen2Model, Qwen2RotaryEmbedding, - apply_rotary_pos_emb, repeat_kv, + rotate_half, ) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding): """ @@ -92,6 +69,9 @@ def forward(self, x, seq_len=None): <<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> e4503c5 (Minor fixes) def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -125,8 +105,13 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" +<<<<<<< HEAD cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) +======= + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) +>>>>>>> e4503c5 (Minor fixes) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -134,8 +119,34 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) +<<<<<<< HEAD ======= >>>>>>> d0f7ffd (Ruff check and format) +======= +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +>>>>>>> e4503c5 (Minor fixes) class QEffQwen2Attention(Qwen2Attention): """ Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py @@ -166,7 +177,11 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) +<<<<<<< HEAD cos, sin = position_embeddings +======= + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) +>>>>>>> e4503c5 (Minor fixes) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 355eb7b73431e2a9cdf9771ec3a014260207aa2a Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Tue, 8 Apr 2025 09:34:58 +0530 Subject: [PATCH 10/16] Enabled Models for TF-4.50.0 (#340) Enabled Models for TF-4.50.0 Models Enabled are 1. GPT2 2. GPTJ 3. Granite 4. Phi 5. Phi3 6. 
Whisper This is the same PR that was raised as #334 which was closed due to rebasing issue Code Cleaned and requested changes done and marked for review QEff Dynamic Cache Added --------- Signed-off-by: Dipankar Sarkar Signed-off-by: Amit Raj --- QEfficient/transformers/models/whisper/modeling_whisper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index 335624b09..803b520c3 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -9,7 +9,8 @@ import torch from torch import nn -from transformers.cache_utils import Cache, StaticCache +from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, From 281ec013d666dcbfeb3b7b23ac05762ab59b9545 Mon Sep 17 00:00:00 2001 From: asmigosw Date: Tue, 8 Apr 2025 10:08:47 +0530 Subject: [PATCH 11/16] Tf version upgrade to 4.50 (#344) Tf version upgrade for mllama, starcoder2 and mistral_moe --------- Signed-off-by: Asmita Goswami Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 38 ------------------- .../models/whisper/modeling_whisper.py | 3 +- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 6eef81bd3..9f02270fe 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -48,44 +48,6 @@ NUM_CHANNEL = 3 -class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - Add static sin/cos computations. - """ - - def __init__(self, config: MllamaConfig, device=None): - super().__init__(config=config) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index 803b520c3..335624b09 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -9,8 +9,7 @@ import torch from torch import nn -from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import Cache, StaticCache from transformers.modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, From 20b82b200cccdbfa2161de5baf519f2da03478a9 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 8 Apr 2025 08:25:04 +0000 Subject: [PATCH 12/16] Fixed issue of quantizer_compressed Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 38 +++++++++++++++++++ .../quantizer_compressed_tensors.py | 8 +--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 9f02270fe..6eef81bd3 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -48,6 +48,44 @@ NUM_CHANNEL = 3 +class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): + """ + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: MllamaConfig, device=None): + super().__init__(config=config) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index 18e814b83..9c8a7166a 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -130,12 +130,7 @@ def forward(self, x): class QEffFP8Config(QuantizationConfigMixin): def __init__( - self, - quant_method: str, - activation_scheme: str, - ignored_layers: List[str] = None, - kv_cache_scheme: str = None, - run_compressed: bool = False, + self, quant_method: str, activation_scheme: str, ignored_layers: List[str] = None, kv_cache_scheme: str = None ): self.quant_method = quant_method self.activation_scheme = activation_scheme @@ -230,7 +225,6 @@ def __init__( ignore=None, sparsity_config=None, quant_method="compressed-tensors", - run_compressed: bool = False, **kwargs, ): self.config_groups = config_groups From 598b83f4ceb7e88d79a5abee8d16447c99bdc641 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 8 Apr 2025 17:22:09 +0000 Subject: [PATCH 13/16] adding classes Signed-off-by: Amit Raj --- .../models/mixtral_moe/modeling_mixtral.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 36bffac31..37d8ccc33 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -206,37 +206,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - + + # Compute routing logits for selecting experts + router_logits = self.gate(hidden_states) # Shape: (batch * seq_len, num_experts) + + # Compute routing probabilities and select top-k experts per token routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # Normalize weights + routing_weights = routing_weights.to(hidden_states.dtype) # Ensure correct dtype + # Initialize final output tensor final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) + # One-hot encode selected experts (batch * seq_len, top_k, num_experts) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts) + expert_mask = expert_mask.to(hidden_states.dtype) # Ensure dtype matches for efficient computation + + # Compute all expert outputs in parallel (batch * seq_len, num_experts, hidden_dim) + expert_outputs = torch.stack([self.experts[i](hidden_states) for i in range(self.num_experts)], dim=1) + + # Efficient expert selection using matrix multiplication + selected_expert_outputs = torch.einsum("bte,beh->bth", expert_mask, expert_outputs) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = 
torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) - current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) - current_hidden_states = torch.where( - (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], - current_hidden_states, - torch.tensor(0.0), - ) - final_hidden_states = final_hidden_states + current_hidden_states - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + # Multiply by routing weights and sum over top_k experts + final_hidden_states = (selected_expert_outputs * routing_weights.unsqueeze(-1)).sum(dim=1) + + # Reshape back to original dimensions + final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) + + return final_hidden_states, router_logits class QeffMixtralDecoderLayer(MixtralDecoderLayer): """ From fc8e7313cc85496e3aa5e9a6a27108e1e06c63e3 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Thu, 10 Apr 2025 06:19:50 +0000 Subject: [PATCH 14/16] Added suppport of Grok1 model Signed-off-by: Amit Raj --- QEfficient/base/pytorch_transforms.py | 4 + .../transformers/models/grok_1/__init__.py | 7 + .../models/grok_1/modeling_grok1.py | 377 ++++++++++++++++++ .../models/mixtral_moe/modeling_mixtral.py | 47 ++- .../transformers/models/modeling_auto.py | 1 + .../transformers/models/pytorch_transforms.py | 18 + .../quantizer_compressed_tensors.py | 8 +- 7 files changed, 437 insertions(+), 25 deletions(-) create mode 100644 QEfficient/transformers/models/grok_1/__init__.py create mode 100644 QEfficient/transformers/models/grok_1/modeling_grok1.py diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index abd19ed35..f97b51489 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -107,6 +107,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: ): for orig_method_name, mapped_method in repl_method_map.items(): setattr(module, orig_method_name, MethodType(mapped_method, module)) + + if hasattr(module, "__qeff_init__"): + module.__qeff_init__() + transformed = True return model, transformed diff --git a/QEfficient/transformers/models/grok_1/__init__.py b/QEfficient/transformers/models/grok_1/__init__.py new file mode 100644 index 000000000..da26921c5 --- /dev/null +++ b/QEfficient/transformers/models/grok_1/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py new file mode 100644 index 000000000..07f8f1e3d --- /dev/null +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -0,0 +1,377 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class QEffGrok1MultiHeadAttention(nn.Module): + def __qeff_init__(self): + self.layer_idx = 0 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float) + attn_weights = attn_weights * self.attn_output_multiplier + attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val) + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class QEffGrok1MoeBlock(nn.Module): + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + 
+ final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) + current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) + current_hidden_states = torch.where( + (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], + current_hidden_states, + torch.tensor(0.0), + ) + final_hidden_states = final_hidden_states + current_hidden_states + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class QEffGrok1DecoderLayer(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + hidden_states, attention_weights, present_key_value = self.attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = self.post_attn_norm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_moe_norm(hidden_states) + hidden_states, router_logits = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (attention_weights,) + if use_cache: + outputs += (present_key_value,) + if output_router_logits: + outputs += (router_logits,) + return outputs + + +class QEffGrok1Model(nn.Module): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds * self.embedding_multiplier_scale + + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class QEffGrok1ModelForCausalLM(nn.Module): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, 
dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + **kwargs, + ) + + # Cast to int32 to avoid ONNXRT issue + logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = self.lm_head(hidden_states) + logits = logits * self.output_multiplier_scale + logits = logits.float() + + return MoeCausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 37d8ccc33..36bffac31 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -206,39 +206,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - - # Compute routing logits for selecting experts - router_logits = self.gate(hidden_states) # Shape: (batch * seq_len, num_experts) - - # Compute routing probabilities and select top-k experts per token + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # Normalize weights - routing_weights = routing_weights.to(hidden_states.dtype) # Ensure correct dtype + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - # Initialize final output tensor final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # One-hot encode selected experts (batch * seq_len, top_k, num_experts) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts) - expert_mask = expert_mask.to(hidden_states.dtype) # Ensure dtype matches for efficient computation - - # Compute all expert outputs in parallel (batch * seq_len, num_experts, hidden_dim) - expert_outputs = torch.stack([self.experts[i](hidden_states) for i in range(self.num_experts)], dim=1) - - # Efficient expert selection using matrix multiplication - selected_expert_outputs = torch.einsum("bte,beh->bth", expert_mask, expert_outputs) - - - # Multiply by routing weights and sum over top_k experts - final_hidden_states = (selected_expert_outputs * routing_weights.unsqueeze(-1)).sum(dim=1) - - # Reshape back to original dimensions - final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) 
+ + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) + current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) + current_hidden_states = torch.where( + (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], + current_hidden_states, + torch.tensor(0.0), + ) + final_hidden_states = final_hidden_states + current_hidden_states + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits + class QeffMixtralDecoderLayer(MixtralDecoderLayer): """ Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0531af7b8..298066504 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1295,6 +1295,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): FP8DeQuantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, + KVCacheModuleMethodMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2e94908c8..2fd971942 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -189,6 +189,14 @@ QEffGraniteMoeRotaryEmbedding, QEffGraniteMoeTopKGating, ) +from QEfficient.transformers.models.grok_1.modeling_grok1 import ( + QEffGrok1DecoderLayer, + QEffGrok1Model, + QEffGrok1ModelForCausalLM, + QEffGrok1MoeBlock, + QEffGrok1MultiHeadAttention, +) + from QEfficient.transformers.models.internvl.modeling_internvl import QEffInternVisionEmbeddings, QEffInternVLModel from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, @@ -468,5 +476,15 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + # #Mapping for grok1 model + "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, + "Grok1Model": {"forward": QEffGrok1Model.forward}, + "DecoderLayer": {"forward": QEffGrok1DecoderLayer.forward}, + "MoeBlock": {"forward": QEffGrok1MoeBlock.forward}, + "MultiHeadAttention": { + "forward": QEffGrok1MultiHeadAttention.forward, + "__qeff_init__": QEffGrok1MultiHeadAttention.__qeff_init__, + }, } + _match_class_replace_method = {} diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index 9c8a7166a..18e814b83 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -130,7 +130,12 @@ def forward(self, x): class QEffFP8Config(QuantizationConfigMixin): def __init__( - self, quant_method: str, activation_scheme: str, ignored_layers: List[str] = None, kv_cache_scheme: str = None + self, + quant_method: str, + activation_scheme: str, + ignored_layers: List[str] = None, + kv_cache_scheme: str = None, + run_compressed: bool = False, 
     ):
         self.quant_method = quant_method
         self.activation_scheme = activation_scheme
@@ -225,6 +230,7 @@ def __init__(
         ignore=None,
         sparsity_config=None,
         quant_method="compressed-tensors",
+        run_compressed: bool = False,
         **kwargs,
     ):
         self.config_groups = config_groups

From b0b50030c81315ec1360b25701833dd29d534235 Mon Sep 17 00:00:00 2001
From: Amit Raj
Date: Fri, 11 Apr 2025 17:15:46 +0000
Subject: [PATCH 15/16] Minor improvement

Signed-off-by: Amit Raj
---
 .../models/grok_1/modeling_grok1.py           | 40 +++++--------------
 .../transformers/models/pytorch_transforms.py |  6 ++-
 2 files changed, 15 insertions(+), 31 deletions(-)

diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py
index 07f8f1e3d..d489b2c42 100644
--- a/QEfficient/transformers/models/grok_1/modeling_grok1.py
+++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py
@@ -9,35 +9,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from QEfficient.transformers.cache_utils import QEffDynamicCache
-from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-
 from transformers.modeling_outputs import (
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
 )
+from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
 
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@@ -70,12 +49,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
 
 class QEffGrok1MultiHeadAttention(nn.Module):
-    def __qeff_init__(self):
-        self.layer_idx = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
+        layer_idx: int,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -96,7 +73,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -108,7 +85,7 @@ def forward(
                 "batch_index": batch_index,
                 "position_ids": position_ids,
             }  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -194,6 +171,7 @@ def forward(
         hidden_states = self.pre_attn_norm(hidden_states)
         hidden_states, attention_weights, present_key_value = self.attn(
             hidden_states,
+            layer_idx=self.layer_idx,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -221,6 +199,10 @@ def forward(
 
 
 class QEffGrok1Model(nn.Module):
+    def __qeff_init__(self):
+        for idx, layer in enumerate(self.layers):
+            layer.layer_idx = idx
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 2fd971942..67bf94901 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -478,12 +478,14 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
         "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
         # Mapping for grok1 model
         "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
-        "Grok1Model": {"forward": QEffGrok1Model.forward},
+        "Grok1Model": {
+            "forward": QEffGrok1Model.forward,
+            "__qeff_init__": QEffGrok1Model.__qeff_init__,
+        },
         "DecoderLayer": {"forward": QEffGrok1DecoderLayer.forward},
         "MoeBlock": {"forward": QEffGrok1MoeBlock.forward},
         "MultiHeadAttention": {
             "forward": QEffGrok1MultiHeadAttention.forward,
-            "__qeff_init__": QEffGrok1MultiHeadAttention.__qeff_init__,
         },
     }
 

From 5a72baca7906083565e23c3ffa569c3cc9c73c8a Mon Sep 17 00:00:00 2001
From: Amit Raj
Date: Mon, 21 Apr 2025 10:58:11 +0000
Subject: [PATCH 16/16] Resolved Conflicts

Signed-off-by: Amit Raj
---
 .../models/gemma/modeling_gemma.py            |  4 +-
 .../models/gemma2/modeling_gemma2.py          | 38 -------------------
 .../models/llama/modeling_llama.py            | 17 ++-------
 .../models/mistral/modeling_mistral.py        |  4 +-
 .../models/qwen2/modeling_qwen2.py            | 18 ---------
 5 files changed, 8 insertions(+), 73 deletions(-)

diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py
index 5282ae484..44142f4e0 100644
--- a/QEfficient/transformers/models/gemma/modeling_gemma.py
+++ b/QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -86,8 +86,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
     # Apply rotation
     q_embed = (q * cos) + (rotate_half(q) * sin)
diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
index b7890b24e..e73d759d8 100644
--- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py
+++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -40,11 +40,7 @@ class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding):
     """
 
     def __init__(self, config: Gemma2Config, device=None):
-<<<<<<< HEAD
         super().__init__(config=config)
-=======
-        Gemma2RotaryEmbedding.__init__(self, config=config)
->>>>>>> 6ba4c76 (Code cleaning and formating)
 
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -136,16 +132,6 @@ class QEffGemma2Attention(Gemma2Attention):
     - add new args cache idx for the kv retention
     """
 
-<<<<<<< HEAD
-=======
-    def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
-        super().__init__(config, layer_idx)
-        # Define the general __qeff_init__() for any changes in the init calls
-        # Set the init in the module mapping pytorch transforms
-        self.config = config
-        self.__qeff_init__()
-
->>>>>>> 6ba4c76 (Code cleaning and formating)
     def __qeff_init__(self):
         self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
 
@@ -355,10 +341,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-<<<<<<< HEAD
-=======
-            position_embeddings=position_embeddings,
->>>>>>> 6ba4c76 (Code cleaning and formating)
             **kwargs,
         )
 
@@ -384,26 +366,6 @@ def forward(
         )
         return output if return_dict else output.to_tuple()
 
-<<<<<<< HEAD
-=======
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
-        causal_mask = _create_causal_mask(
-            position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window
-        )
-
-        return causal_mask
-
->>>>>>> 6ba4c76 (Code cleaning and formating)
 
 class QEffGemma2ForCausalLM(Gemma2ForCausalLM, GenerationMixin):
     """
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index d20fcb8d5..dae783361 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -85,8 +85,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) @@ -94,17 +94,6 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) def eager_attention_forward( module: nn.Module, @@ -142,6 +131,8 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 520488e1d..183b07b3a 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -140,6 +140,8 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -199,7 +201,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -230,7 +231,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index ef2300ae6..0ea22cead 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -68,10 +68,6 @@ def forward(self, x, seq_len=None): ) -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> e4503c5 (Minor fixes) def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -105,13 +101,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" -<<<<<<< HEAD - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) -======= cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) ->>>>>>> e4503c5 (Minor fixes) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -119,10 +110,6 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) -<<<<<<< HEAD -======= ->>>>>>> d0f7ffd (Ruff check and format) -======= def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -146,7 +133,6 @@ def eager_attention_forward( return attn_output, attn_weights ->>>>>>> e4503c5 (Minor fixes) class QEffQwen2Attention(Qwen2Attention): """ Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py @@ -177,11 +163,7 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) -<<<<<<< HEAD - cos, sin = position_embeddings -======= cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) ->>>>>>> e4503c5 (Minor fixes) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: