From 56c5030396a5c3d8fa4ae9da5548d30d06f95108 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Fri, 4 Apr 2025 10:09:12 -0700 Subject: [PATCH 1/7] Onboard Plamo model Signed-off-by: quic-shagun --- QEfficient/base/pytorch_transforms.py | 3 + .../transformers/models/modeling_auto.py | 1 + .../transformers/models/plamo/__init__.py | 6 + .../models/plamo/modeling_plamo.py | 742 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 16 + QEfficient/utils/constants.py | 4 +- 6 files changed, 770 insertions(+), 2 deletions(-) create mode 100644 QEfficient/transformers/models/plamo/__init__.py create mode 100644 QEfficient/transformers/models/plamo/modeling_plamo.py diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index abd19ed35..6bb481dc4 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -107,6 +107,9 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: ): for orig_method_name, mapped_method in repl_method_map.items(): setattr(module, orig_method_name, MethodType(mapped_method, module)) + # Handling the __init__ calls in the models + if hasattr(module, "__qeff_init__"): + module.__qeff_init__() transformed = True return model, transformed diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index eb3748bc5..ae648bd3e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1280,6 +1280,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): FP8DeQuantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, + KVCacheModuleMethodMapperTransform, ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] diff --git a/QEfficient/transformers/models/plamo/__init__.py b/QEfficient/transformers/models/plamo/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/plamo/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py new file mode 100644 index 000000000..351d8b2df --- /dev/null +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -0,0 +1,742 @@ +from asyncio.log import logger +import math +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.cache_utils import Cache, DynamicCache, StaticCache + +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class QEffPlamoConfig(PretrainedConfig): # type: ignore + model_type: str = "plamo" + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 4096, + intermediate_size: int = 13312, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = None, + max_position_embeddings: int = 2048, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-6, + use_cache: bool = True, + tokenizer_class: str = "PlamoTokenizer", + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + n_shared_head: int = 8, + tie_word_embeddings: bool = False, + **kwargs: Any, + ) -> None: + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.n_shared_head = n_shared_head + + super().__init__( + tokenizer_class=tokenizer_class, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +) -> torch.Tensor: + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor: + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # type: ignore + + +class QEffPlamoRotaryEmbedding(torch.nn.Module): + def __init__( + self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None + ) -> None: + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore + ) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (_rotate_half(x) * sin) + return x_embed + + +class QEffPlamoRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class QEffPlamoAttention(torch.nn.Module): + def __init__(self, config: QEffPlamoConfig, layer_idx: Optional[int] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + head_dim = self.hidden_size // config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + + self.q_num_heads = config.num_attention_heads + self.qk_dim = self.v_dim = head_dim + self.k_num_heads = self.v_num_heads = int(np.ceil(self.q_num_heads / config.n_shared_head)) + + self.q_proj = nn.Linear(self.hidden_size, self.q_num_heads * self.qk_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.k_num_heads * self.qk_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.v_num_heads * self.v_dim, bias=False) + self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False) + self.rotary_emb = QEffPlamoRotaryEmbedding(self.qk_dim, max_position_embeddings=self.max_position_embeddings) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + batch_index: Optional[torch.Tensor] = None, + layer_idx: Optional[int] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2) + + def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: + return t.repeat(1, repeat, 1, 1)[:, :target] + + # expand shared kv + assert self.k_num_heads == self.v_num_heads + key_states = _expand_kv(key_states, self.config.n_shared_head, self.q_num_heads) + value_states = _expand_kv(value_states, self.config.n_shared_head, self.q_num_heads) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + assert position_ids is not None + query_states = _rotary_pos_emb(query_states, cos, sin, position_ids) + key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MLP(nn.Module): + def __init__(self, config: QEffPlamoConfig) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = torch.nn.functional.silu + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore + + +class QEffPlamoDecoderLayer(torch.nn.Module): + def __init__(self, config: QEffPlamoConfig, layer_idx: int) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = QEffPlamoAttention(config, layer_idx) + self.mlp = MLP(config) + self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + layer_idx: Optional[int] = None, + ) -> Tuple[Any, ...]: + # from LlamaDecoder + residual = hidden_states + + hidden_states = self.norm(hidden_states) + + # Self Attention + hidden_states_sa, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + layer_idx=layer_idx, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + + # Residual + hidden_states = residual + hidden_states_sa + hidden_states_mlp + + outputs: Any = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs # type: ignore + + +class QEffPlamoDecoder(torch.nn.Module): + def __init__(self, config: QEffPlamoConfig) -> None: + super().__init__() + self.layers = torch.nn.ModuleList([QEffPlamoDecoderLayer(config,layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + def 
forward(self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + gradient_checkpointing: bool = False, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if output_hidden_states else None + all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if output_attentions else None + next_decoder_cache: Optional[Tuple[torch.Tensor, ...]] = () if use_cache else None + hidden_states = hidden_states + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states += (hidden_states,) + + #past_key_value = x.past_key_values[idx] if x.past_key_values is not None else None + + if self.training and gradient_checkpointing: + + def create_custom_forward(module): # type: ignore + def custom_forward(*inputs): # type: ignore + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), # type: ignore + hidden_states, + attention_mask, + position_ids, + None, + ) + elif batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + layer_idx=idx, + batch_index=batch_index, + cache_position=cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + layer_idx=idx, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + cache = layer_outputs[2 if output_attentions else 1] + assert cache is not None + assert next_decoder_cache is not None + next_decoder_cache = cache + + if output_attentions: + assert layer_outputs[1] is not None + assert all_self_attns is not None + all_self_attns += (layer_outputs[1],) + + return (hidden_states, all_hidden_states, all_self_attns, next_decoder_cache) + + +class QEffPlamoPreTrainedModel(PreTrainedModel): # type: ignore + config_class = QEffPlamoConfig + _no_split_modules: List[str] + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PlamoDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module: torch.nn.Module) -> None: + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None: + module.gradient_checkpointing = value # type: ignore + + +class QEffPlamoModel(QEffPlamoPreTrainedModel): + def __init__(self, config: 
QEffPlamoConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = QEffPlamoDecoder(config) # type: ignore + self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> torch.nn.Embedding: + return self.embed_tokens + + def set_input_embeddings(self, value: torch.nn.Embedding) -> None: + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: Optional[torch.FloatTensor], + past_key_values_length: int, + ) -> Optional[torch.Tensor]: + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask: Optional[torch.Tensor] = None + if input_shape[-1] > 1: + assert inputs_embeds is not None + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + assert inputs_embeds is not None + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + assert input_ids is not None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = 
DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + layer_outputs = self.layers( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + batch_index=batch_index + ) + + hidden_states = layer_outputs[0] + all_hidden_states = layer_outputs[1] + all_self_attns = layer_outputs[2] + next_decoder_cache = layer_outputs[3] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class QEffPlamoForCausalLM(QEffPlamoPreTrainedModel): + def __init__(self, config: PretrainedConfig) -> None: + super().__init__(config) + self.model = QEffPlamoModel(config) + + self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> torch.nn.Embedding: + return self.model.embed_tokens + + def set_input_embeddings(self, value: torch.nn.Embedding) -> None: + self.model.embed_tokens = value + + def get_output_embeddings(self) -> torch.nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None: + self.lm_head = new_embeddings + + def set_decoder(self, decoder: QEffPlamoModel) -> None: + self.model = decoder + + def get_decoder(self) -> QEffPlamoModel: + return self.model + + def forward( # type: ignore + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + 
return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + assert input_ids is not None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + 
position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]: + reordered_past: Tuple[Any, ...] = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index bcedd4a27..5061bd7e0 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -225,6 +225,15 @@ QEffPhi3ForCausalLM, QEffPhi3Model, ) +from QEfficient.transformers.models.plamo.modeling_plamo import ( + QEffPlamoAttention, + QEffPlamoDecoder, + QEffPlamoDecoderLayer, + QEffPlamoForCausalLM, + QEffPlamoModel, + QEffPlamoRMSNorm, + QEffPlamoRotaryEmbedding +) from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( QEffQwen2Attention, QEffQwen2DecoderLayer, @@ -439,5 +448,12 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + "PlamoForCausalLM" : {"forward": QEffPlamoForCausalLM.forward}, + "PlamoModel": {"forward": QEffPlamoModel.forward}, + "PlamoDecoder":{ "forward": QEffPlamoDecoder.forward}, + "PlamoDecoderLayer":{ "forward": QEffPlamoDecoderLayer.forward}, + "Attention": {"forward": QEffPlamoAttention.forward}, + "RMSNorm":{"forward": QEffPlamoRMSNorm.forward}, + "RotaryEmbedding":{"forward": QEffPlamoRotaryEmbedding.forward}, } _match_class_replace_method = {} diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 74313044a..20671cc71 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -17,7 +17,7 @@ ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 14 ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 @@ -60,7 +60,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 14 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] From de15585b5bd2c645b98f2a9ef6dabbe8e1bb5f57 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Mon, 7 Apr 2025 12:04:16 -0700 Subject: [PATCH 2/7] Update Plamo modeling file to support Opset 13 Signed-off-by: quic-shagun --- .../models/plamo/modeling_plamo.py | 81 ++++++++++++++++--- QEfficient/utils/constants.py | 4 +- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py index 351d8b2df..d99cf0f4e 100644 --- 
a/QEfficient/transformers/models/plamo/modeling_plamo.py +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -1,14 +1,15 @@ from asyncio.log import logger import math from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union - import numpy as np import torch from torch import nn from torch.nn import functional as F from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.cache_utils import Cache, StaticCache +from QEfficient.customop.rms_norm import CustomRMSNorm +from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -151,6 +152,35 @@ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, posit x_embed = (x * cos) + (_rotate_half(x) * sin) return x_embed +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) class QEffPlamoRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: @@ -196,7 +226,7 @@ def forward( layer_idx: Optional[int] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2) @@ -219,22 +249,43 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: "with a layer index." 
) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) - + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - assert position_ids is not None + # query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states = _rotary_pos_emb(query_states, cos, sin, position_ids) key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) - attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) - attn_output = attn_output.transpose(1, 2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.qk_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.q_num_heads, q_len, self.qk_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.q_num_heads, q_len, self.qk_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + # attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + # attn_output = attn_output.transpose(1, 2) + + #attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -267,6 +318,9 @@ def __init__(self, config: QEffPlamoConfig, layer_idx: int) -> None: self.mlp = MLP(config) self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def __qeff_init__(self,): + self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def forward( self, hidden_states: torch.Tensor, @@ -438,6 +492,9 @@ def __init__(self, config: QEffPlamoConfig): # Initialize weights and apply final processing self.post_init() + def __qeff_init__(self,): + self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def get_input_embeddings(self) -> torch.nn.Embedding: return self.embed_tokens @@ -516,9 +573,9 @@ def forward( if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True if past_key_values is None: - past_key_values = DynamicCache() + past_key_values = QEffDynamicCache() else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 20671cc71..74313044a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -17,7 +17,7 @@ ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 14 +ONNX_EXPORT_OPSET = 13 ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 @@ -60,7 +60,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 14 +ONNX_EXPORT_OPSET = 13 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] From ed9d55d23e4f2d4e7a01f4541447afb83ebd4558 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Tue, 8 Apr 2025 22:14:48 -0700 Subject: [PATCH 3/7] Fix Plamo accuracy issues Signed-off-by: quic-shagun --- .../models/plamo/modeling_plamo.py | 138 +++++------------- 1 file changed, 36 insertions(+), 102 deletions(-) diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py index d99cf0f4e..b90b631d3 100644 --- a/QEfficient/transformers/models/plamo/modeling_plamo.py +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -1,16 +1,14 @@ from asyncio.log import logger import math -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch import nn -from torch.nn import functional as F from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.cache_utils import Cache, StaticCache +from transformers.cache_utils import Cache from QEfficient.customop.rms_norm import CustomRMSNorm from QEfficient.transformers.cache_utils import QEffDynamicCache - from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -195,6 +193,25 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(module.qk_dim) + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights class QEffPlamoAttention(torch.nn.Module): def __init__(self, config: QEffPlamoConfig, layer_idx: Optional[int] = None) -> None: @@ -225,6 +242,7 @@ def forward( batch_index: Optional[torch.Tensor] = None, layer_idx: Optional[int] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -254,38 +272,24 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> 
torch.Tensor: # query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states = _rotary_pos_emb(query_states, cos, sin, position_ids) key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) - - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.qk_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + attention_interface: Callable = eager_attention_forward - if attn_output.size() != (bsz, self.q_num_heads, q_len, self.qk_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.q_num_heads, q_len, self.qk_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) - - # attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) - # attn_output = attn_output.transpose(1, 2) - - #attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -293,7 +297,6 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: return attn_output, attn_weights, past_key_value - class MLP(nn.Module): def __init__(self, config: QEffPlamoConfig) -> None: super().__init__() @@ -381,7 +384,6 @@ def forward(self, output_hidden_states: Optional[bool] = False, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - gradient_checkpointing: bool = False, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -396,25 +398,7 @@ def forward(self, assert all_hidden_states is not None all_hidden_states += (hidden_states,) - #past_key_value = x.past_key_values[idx] if x.past_key_values is not None else None - - if self.training and gradient_checkpointing: - - def create_custom_forward(module): # type: ignore - def custom_forward(*inputs): # type: ignore - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), # type: ignore - hidden_states, - attention_mask, - position_ids, - None, - ) - elif batch_index is not None: + if batch_index is not None: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -501,38 +485,6 @@ def get_input_embeddings(self) -> torch.nn.Embedding: def set_input_embeddings(self, value: torch.nn.Embedding) -> None: self.embed_tokens = value - # Copied from 
transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - inputs_embeds: Optional[torch.FloatTensor], - past_key_values_length: int, - ) -> Optional[torch.Tensor]: - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask: Optional[torch.Tensor] = None - if input_shape[-1] > 1: - assert inputs_embeds is not None - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - assert inputs_embeds is not None - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -603,13 +555,14 @@ def forward( # decoder layers layer_outputs = self.layers( hidden_states=hidden_states, - attention_mask=attention_mask, position_ids=position_ids, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - batch_index=batch_index + batch_index=batch_index, ) hidden_states = layer_outputs[0] @@ -682,25 +635,6 @@ def forward( # type: ignore cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: - Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
- ```""" assert input_ids is not None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions From e426634c464483c44baa2b2b2de4003ada7d2a41 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Tue, 8 Apr 2025 23:23:37 -0700 Subject: [PATCH 4/7] Fix lint issues Signed-off-by: quic-shagun --- .../transformers/models/plamo/modeling_plamo.py | 6 ++++-- .../transformers/models/pytorch_transforms.py | 13 ++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py index b90b631d3..a8a36869f 100644 --- a/QEfficient/transformers/models/plamo/modeling_plamo.py +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -1,12 +1,14 @@ -from asyncio.log import logger import math +from asyncio.log import logger from typing import Any, Callable, Dict, List, Optional, Tuple, Union + import numpy as np import torch from torch import nn from transformers import PretrainedConfig, PreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + from QEfficient.customop.rms_norm import CustomRMSNorm from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5061bd7e0..a17218d2a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -227,12 +227,12 @@ ) from QEfficient.transformers.models.plamo.modeling_plamo import ( QEffPlamoAttention, - QEffPlamoDecoder, - QEffPlamoDecoderLayer, - QEffPlamoForCausalLM, - QEffPlamoModel, - QEffPlamoRMSNorm, - QEffPlamoRotaryEmbedding + QEffPlamoDecoder, + QEffPlamoDecoderLayer, + QEffPlamoForCausalLM, + QEffPlamoModel, + QEffPlamoRMSNorm, + QEffPlamoRotaryEmbedding, ) from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( QEffQwen2Attention, @@ -434,7 +434,6 @@ class VlmNoKVOffloadTransform(ModuleMappingTransform): MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC, } - class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): _match_string_replace_method = { "InternVLChatModel": { From 537ba6f04371ff3bef768444b98fda29c74b5b63 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Tue, 8 Apr 2025 23:31:23 -0700 Subject: [PATCH 5/7] Add Qualcomm Signature in new modeling file Signed-off-by: quic-shagun --- .../models/plamo/modeling_plamo.py | 79 +++++++++++-------- .../transformers/models/pytorch_transforms.py | 11 +-- 2 files changed, 54 insertions(+), 36 deletions(-) diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py index a8a36869f..bb4c42c90 100644 --- a/QEfficient/transformers/models/plamo/modeling_plamo.py +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + import math from asyncio.log import logger from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -152,6 +159,7 @@ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, posit x_embed = (x * cos) + (_rotate_half(x) * sin) return x_embed + def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -182,6 +190,7 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) + class QEffPlamoRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() @@ -195,6 +204,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -212,9 +222,10 @@ def eager_attention_forward( attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() - + return attn_output, attn_weights + class QEffPlamoAttention(torch.nn.Module): def __init__(self, config: QEffPlamoConfig, layer_idx: Optional[int] = None) -> None: super().__init__() @@ -246,7 +257,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2) @@ -269,15 +280,15 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: "with a layer index." 
) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) - + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states = _rotary_pos_emb(query_states, cos, sin, position_ids) key_states = _rotary_pos_emb(key_states, cos, sin, position_ids) - + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -290,7 +301,7 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: attention_mask, **kwargs, ) - + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) @@ -299,6 +310,7 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: return attn_output, attn_weights, past_key_value + class MLP(nn.Module): def __init__(self, config: QEffPlamoConfig) -> None: super().__init__() @@ -323,9 +335,11 @@ def __init__(self, config: QEffPlamoConfig, layer_idx: int) -> None: self.mlp = MLP(config) self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def __qeff_init__(self,): + def __qeff_init__( + self, + ): self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - + def forward( self, hidden_states: torch.Tensor, @@ -376,25 +390,27 @@ def forward( class QEffPlamoDecoder(torch.nn.Module): def __init__(self, config: QEffPlamoConfig) -> None: super().__init__() - self.layers = torch.nn.ModuleList([QEffPlamoDecoderLayer(config,layer_idx) for layer_idx in range(config.num_hidden_layers)]) - - def forward(self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - output_hidden_states: Optional[bool] = False, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - batch_index: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + self.layers = torch.nn.ModuleList( + [QEffPlamoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if output_hidden_states else None all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if output_attentions else None next_decoder_cache: Optional[Tuple[torch.Tensor, ...]] = () if use_cache else None hidden_states = hidden_states - + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: assert all_hidden_states is not None @@ -478,7 
+494,9 @@ def __init__(self, config: QEffPlamoConfig): # Initialize weights and apply final processing self.post_init() - def __qeff_init__(self,): + def __qeff_init__( + self, + ): self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) def get_input_embeddings(self) -> torch.nn.Embedding: @@ -518,7 +536,6 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -541,10 +558,10 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - + if position_ids is None: position_ids = cache_position.unsqueeze(0) - + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) @@ -566,7 +583,7 @@ def forward( cache_position=cache_position, batch_index=batch_index, ) - + hidden_states = layer_outputs[0] all_hidden_states = layer_outputs[1] all_self_attns = layer_outputs[2] @@ -582,10 +599,10 @@ def forward( next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - + if return_legacy_cache: next_cache = next_cache.to_legacy_cache() - + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -644,7 +661,7 @@ def forward( # type: ignore output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -663,7 +680,7 @@ def forward( # type: ignore # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - + logits = self.lm_head(hidden_states) logits = logits.float() diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index a17218d2a..1306de0c3 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -434,6 +434,7 @@ class VlmNoKVOffloadTransform(ModuleMappingTransform): MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC, } + class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): _match_string_replace_method = { "InternVLChatModel": { @@ -447,12 +448,12 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, - "PlamoForCausalLM" : {"forward": QEffPlamoForCausalLM.forward}, + "PlamoForCausalLM": {"forward": QEffPlamoForCausalLM.forward}, "PlamoModel": {"forward": QEffPlamoModel.forward}, - "PlamoDecoder":{ "forward": QEffPlamoDecoder.forward}, - "PlamoDecoderLayer":{ "forward": QEffPlamoDecoderLayer.forward}, + "PlamoDecoder": {"forward": QEffPlamoDecoder.forward}, + "PlamoDecoderLayer": {"forward": QEffPlamoDecoderLayer.forward}, "Attention": {"forward": QEffPlamoAttention.forward}, - "RMSNorm":{"forward": QEffPlamoRMSNorm.forward}, - 
"RotaryEmbedding":{"forward": QEffPlamoRotaryEmbedding.forward}, + "RMSNorm": {"forward": QEffPlamoRMSNorm.forward}, + "RotaryEmbedding": {"forward": QEffPlamoRotaryEmbedding.forward}, } _match_class_replace_method = {} From 1516dee23b8967afa1e90100b992caf20d063807 Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Wed, 9 Apr 2025 02:27:13 -0700 Subject: [PATCH 6/7] nit: Add Plamo in test file and Update README Signed-off-by: quic-shagun --- README.md | 1 + tests/transformers/models/test_causal_lm_models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 2185c9f64..bfce31771 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ - [09/2024] [AWQ](https://arxiv.org/abs/2306.00978)/[GPTQ](https://arxiv.org/abs/2210.17323) 4-bit quantized models are supported
 - [09/2024] Now we support [PEFT](https://huggingface.co/docs/peft/index) models
+- [04/2025] Added support for [PLaMo](https://huggingface.co/pfnet/plamo-13b-instruct)
 - [01/2025] Added support for [Ibm-Granite] (https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
 - [01/2025] Added support for [Ibm-Granite-Guardian] (https://huggingface.co/ibm-granite/granite-guardian-3.1-8b)
 - [09/2024] Added support for [Gemma-2-Family](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 418386780..82b899db4 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -44,6 +44,7 @@ "neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored "ibm-granite/granite-3.1-2b-instruct", "ibm-granite/granite-guardian-3.1-2b", + "pfnet/plamo-13b-instruct", ] spd_test_models = [ From 5c100c5f0fe1ca18c3fc41fe0e17286b456ea4bd Mon Sep 17 00:00:00 2001 From: quic-shagun Date: Wed, 7 May 2025 22:06:23 -0700 Subject: [PATCH 7/7] Update modeling file as per latest guidelines Signed-off-by: quic-shagun --- .../models/plamo/modeling_plamo.py | 242 +----------------- 1 file changed, 13 insertions(+), 229 deletions(-) diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py index bb4c42c90..17b3270c6 100644 --- a/QEfficient/transformers/models/plamo/modeling_plamo.py +++ b/QEfficient/transformers/models/plamo/modeling_plamo.py @@ -6,10 +6,8 @@ # ----------------------------------------------------------------------------- import math -from asyncio.log import logger from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch from torch import nn from transformers import PretrainedConfig, PreTrainedModel @@ -72,56 +70,7 @@ def __init__( ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -) -> torch.Tensor: - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor: - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # type: ignore - - class QEffPlamoRotaryEmbedding(torch.nn.Module): - def __init__( - self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None - ) -> None: - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None: self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore @@ -160,43 +109,7 @@ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, posit return x_embed -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - # Apply rotation - q_embed = (q * cos) + (_rotate_half(q) * sin) - k_embed = (k * cos) + (_rotate_half(k) * sin) - # Cast back to original dtype - return q_embed.to(q.dtype), k_embed.to(k.dtype) - - class QEffPlamoRMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -227,23 +140,6 @@ def eager_attention_forward( class QEffPlamoAttention(torch.nn.Module): - def __init__(self, config: QEffPlamoConfig, layer_idx: Optional[int] = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - head_dim = self.hidden_size // config.num_attention_heads - self.max_position_embeddings = config.max_position_embeddings - - self.q_num_heads = config.num_attention_heads - self.qk_dim = self.v_dim = head_dim - self.k_num_heads = self.v_num_heads = int(np.ceil(self.q_num_heads / config.n_shared_head)) - - self.q_proj = nn.Linear(self.hidden_size, self.q_num_heads * self.qk_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.k_num_heads * self.qk_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.v_num_heads * self.v_dim, bias=False) - self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False) - self.rotary_emb = QEffPlamoRotaryEmbedding(self.qk_dim, max_position_embeddings=self.max_position_embeddings) - def forward( self, hidden_states: torch.Tensor, @@ -312,29 +208,11 @@ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor: class MLP(nn.Module): - def __init__(self, config: QEffPlamoConfig) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = torch.nn.functional.silu - def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore class QEffPlamoDecoderLayer(torch.nn.Module): - def __init__(self, config: QEffPlamoConfig, layer_idx: int) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.self_attn = QEffPlamoAttention(config, layer_idx) - self.mlp = MLP(config) - self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def __qeff_init__( self, ): @@ -388,12 +266,6 @@ def forward( class QEffPlamoDecoder(torch.nn.Module): - def __init__(self, config: QEffPlamoConfig) -> None: - super().__init__() - self.layers = torch.nn.ModuleList( - [QEffPlamoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - def forward( self, hidden_states: torch.Tensor, @@ -416,29 +288,16 @@ def forward( assert all_hidden_states is not None all_hidden_states += (hidden_states,) - if batch_index is not None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - 
use_cache=use_cache, - layer_idx=idx, - batch_index=batch_index, - cache_position=cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - layer_idx=idx, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + layer_idx=idx, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] @@ -481,30 +340,11 @@ def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = Fal class QEffPlamoModel(QEffPlamoPreTrainedModel): - def __init__(self, config: QEffPlamoConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = QEffPlamoDecoder(config) # type: ignore - self.norm = QEffPlamoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - def __qeff_init__( self, ): self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings(self) -> torch.nn.Embedding: - return self.embed_tokens - - def set_input_embeddings(self, value: torch.nn.Embedding) -> None: - self.embed_tokens = value - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -529,12 +369,8 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -543,15 +379,7 @@ def forward( return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True - if past_key_values is None: - past_key_values = QEffDynamicCache() - else: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + past_key_values = QEffDynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -612,33 +440,6 @@ def forward( class QEffPlamoForCausalLM(QEffPlamoPreTrainedModel): - def __init__(self, config: PretrainedConfig) -> None: - super().__init__(config) - self.model = QEffPlamoModel(config) - - self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> torch.nn.Embedding: - return self.model.embed_tokens - - def set_input_embeddings(self, value: torch.nn.Embedding) -> None: - self.model.embed_tokens = value - - def get_output_embeddings(self) -> torch.nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None: - self.lm_head = new_embeddings - - def set_decoder(self, decoder: QEffPlamoModel) -> None: - self.model = decoder - - def get_decoder(self) -> QEffPlamoModel: - return self.model - def forward( # type: ignore self, input_ids: Optional[torch.LongTensor] = None, @@ -684,25 +485,8 @@ def forward( # type: ignore logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( - loss=loss, + loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states,
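The causal-LM forward kept by the final patch gathers only the hidden state at each sequence's last valid position before applying `lm_head`, using `position_ids.to(torch.int32).argmax(1, keepdim=True)` (the int32 cast works around an ONNX Runtime issue, per the inline comment). A small standalone sketch of that gather with toy shapes is below; the padding convention used here (padded slots carrying position ids no larger than the real tokens') is an assumption for illustration only.

```python
import torch

# Toy illustration of the last-valid-token gather used before lm_head in the
# patched forward. Assumes padded slots carry position ids no larger than the
# real tokens' ids, so argmax lands on the last real token of each sequence.
batch, seq_len, hidden = 2, 5, 8
hidden_states = torch.randn(batch, seq_len, hidden)
position_ids = torch.tensor([[0, 1, 2, 3, 4],
                             [0, 1, 2, 0, 0]])  # row 2: three real tokens, two padded

# The int32 cast mirrors the ONNX Runtime workaround noted in the patch.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)          # [[4], [2]]
last_hidden = hidden_states[torch.arange(batch).view(-1, 1), logit_index]   # shape (2, 1, 8)
print(logit_index.squeeze(1).tolist(), last_hidden.shape)
```

This keeps the exported graph emitting logits for a single position per batch entry, which is why the loss/shift-labels path could be dropped in the same patch.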
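To exercise the onboarded model end to end, a minimal sketch follows. It assumes the high-level `QEFFAutoModelForCausalLM` export/compile/generate flow picks up the Plamo method mappings registered in `KVCacheModuleMethodMapperTransform`, and that `trust_remote_code` is forwarded to the underlying Hugging Face loader; the compile settings shown are illustrative placeholders, not values required by this patch series.

```python
# Sketch only: run the newly supported Plamo card through QEfficient's
# high-level causal-LM flow. Compile options below are placeholders; adjust
# cores and sequence lengths to the target AI 100 setup.
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model_card = "pfnet/plamo-13b-instruct"  # same card added to test_causal_lm_models.py

tokenizer = AutoTokenizer.from_pretrained(model_card, trust_remote_code=True)
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_card, trust_remote_code=True)

qeff_model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16)  # exports to ONNX, then compiles
qeff_model.generate(prompts=["What is the capital of Japan?"], tokenizer=tokenizer)
```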