
Support for Grok1ModelForCausalLM #360


Closed
wants to merge 16 commits
4 changes: 4 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -107,6 +107,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))

if hasattr(module, "__qeff_init__"):
module.__qeff_init__()

transformed = True

return model, transformed
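
The new `hasattr` check gives a mapped module a one-time initialization hook that runs right after its methods are remapped. A minimal sketch of the contract (the class and attribute names here are illustrative; only the `__qeff_init__` name comes from this hunk):

class QEffExampleModel(nn.Module):
    def __qeff_init__(self):
        # Invoked once by the transform above, after method remapping; a typical
        # use is stamping per-layer metadata, as QEffGrok1Model does further down
        # by assigning layer_idx to each decoder layer.
        for idx, layer in enumerate(self.layers):
            layer.layer_idx = idx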
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/grok_1/__init__.py
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

359 changes: 359 additions & 0 deletions QEfficient/transformers/models/grok_1/modeling_grok1.py
@@ -0,0 +1,359 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offset position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
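# Shape sketch (illustrative only, not part of the model): with 4 heads, head_dim 8,
# and a sequence of 10 tokens,
#   q = torch.randn(1, 4, 10, 8); k = torch.randn(1, 4, 10, 8)
#   cos = torch.randn(32, 8); sin = torch.randn(32, 8)      # per-position tables
#   position_ids = torch.arange(10).unsqueeze(0)            # shape (1, 10)
#   q_rot, k_rot = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# cos[position_ids] is (1, 10, 8); unsqueeze_dim=1 makes it (1, 1, 10, 8), which
# broadcasts against q and k of shape (batch, heads, seq_len, head_dim).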


class QEffGrok1MultiHeadAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float)
attn_weights = attn_weights * self.attn_output_multiplier
attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)

if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)

attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
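# Note on the capped logits above (descriptive comment, not from the PR): scaling
# by attn_output_multiplier and applying max_attn_val * tanh(x / max_attn_val)
# bounds every attention score to (-max_attn_val, max_attn_val) before masking
# and softmax, mirroring the reference Grok-1 attention-score capping.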


class QEffGrok1MoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be used for each token
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
expert_mask_tr = expert_mask[expert_idx].transpose(0, 1)
current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None])
current_hidden_states = torch.where(
(routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None],
current_hidden_states,
torch.tensor(0.0),
)
final_hidden_states = final_hidden_states + current_hidden_states
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
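# Routing sketch (descriptive, not part of the model): selected_experts is
# (tokens, top_k), so one_hot(...) is (tokens, top_k, num_experts) and
# permute(2, 1, 0) yields expert_mask of shape (num_experts, top_k, tokens).
# For each expert, (routing_weights * expert_mask_tr).sum(1) is a token's routing
# weight for that expert (zero when it was not selected), and torch.where zeroes
# out the unselected tokens' contributions. Every expert runs over every token,
# which keeps tensor shapes static at the cost of redundant compute.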


class QEffGrok1DecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.pre_attn_norm(hidden_states)
hidden_states, attention_weights, present_key_value = self.attn(
hidden_states,
layer_idx=self.layer_idx,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = self.post_attn_norm(hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.pre_moe_norm(hidden_states)
hidden_states, router_logits = self.moe_block(hidden_states)
hidden_states = self.post_moe_norm(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (attention_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs


class QEffGrok1Model(nn.Module):
def __qeff_init__(self):
for idx, layer in enumerate(self.layers):
layer.layer_idx = idx

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds * self.embedding_multiplier_scale

attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length)
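# Note: the causal mask is rebuilt here from position_ids and the cache length,
# so the attention_mask argument passed in is effectively superseded.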

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = () if use_cache else None

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_self_attns += (layer_outputs[1],)

if output_router_logits:
all_router_logits += (layer_outputs[-1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_values = past_key_values.to_legacy_cache()

return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)


class QEffGrok1ModelForCausalLM(nn.Module):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
**kwargs,
)

# Cast to int32 to avoid ONNXRT issue
logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True)
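# position_ids.argmax(1) picks the index of the highest (last) position in each
# sequence, so lm_head below only runs on one token per batch entry.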
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx]
logits = self.lm_head(hidden_states)
logits = logits * self.output_multiplier_scale
logits = logits.float()

return MoeCausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
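
For context, a hedged sketch of how the new mapping might be driven through the existing auto class; the checkpoint name, the trust_remote_code usage, and the exact compile/generate arguments are assumptions, not taken from this PR:

from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

# Hypothetical Grok-1 checkpoint that exposes Grok1ModelForCausalLM via remote code.
model_id = "hpcai-tech/grok-1"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.compile(num_cores=16)                      # export to ONNX, then compile
model.generate(prompts=["Hello"], tokenizer=tokenizer)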
1 change: 1 addition & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1295,6 +1295,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
FP8DeQuantLinearToLinearTransform,
CustomOpsTransform,
KVCacheTransform,
KVCacheModuleMethodMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
