Llama4 chunked attention support #395

Open · wants to merge 3 commits into base: add_llama4
31 changes: 26 additions & 5 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -566,8 +566,13 @@ def forward(
         key_states = key_states.transpose(1, 2)

         if past_key_value is not None:
+            chunk_postion_ids = position_ids
+
+            if self.use_rope:
+                chunk_postion_ids = chunk_postion_ids % self.config.attention_chunk_size
+
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_postion_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         attention_interface: Callable = eager_attention_forward
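
For context, a minimal standalone sketch (not part of the PR) of what the modulo remapping above does: absolute token positions are folded into slots of a chunk-sized KV cache, so the same `attention_chunk_size` slots are reused for each successive chunk. The chunk size and positions below are illustrative toy values.

```python
import torch

# Illustrative only: fold absolute positions into chunk-local cache slots,
# mirroring `chunk_postion_ids = position_ids % attention_chunk_size` above.
attention_chunk_size = 8  # toy value; the Llama4 config uses a much larger size

position_ids = torch.arange(20).unsqueeze(0)              # absolute positions 0..19
chunk_position_ids = position_ids % attention_chunk_size  # chunk-local slot per token

print(position_ids[0].tolist())        # [0, 1, ..., 19]
print(chunk_position_ids[0].tolist())  # [0..7, 0..7, 0..3] -> same slots reused per chunk
```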
@@ -713,10 +718,12 @@ def forward(
             position_ids = cache_position.unsqueeze(0)

         causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
+        chunk_causal_mask = None

-        _, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        if past_seen_tokens > self.config.attention_chunk_size:
+            chunked_position_ids = position_ids % self.config.attention_chunk_size
+            target_length = min(past_seen_tokens, torch.tensor(self.config.attention_chunk_size))
+            chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)

         # embed positions
         hidden_states = inputs_embeds
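
A hedged sketch of what the new branch builds once more than `attention_chunk_size` tokens have been seen: positions are reduced modulo the chunk size and a causal mask is constructed over at most `attention_chunk_size` cached slots, so each query only attends to entries from its own chunk. `_create_causal_mask` is the repo's helper; the standalone function below is only an illustrative stand-in for that behaviour, not the repo implementation.

```python
import torch

def toy_chunked_causal_mask(position_ids: torch.Tensor, target_length: int, chunk_size: int) -> torch.Tensor:
    """Illustrative stand-in: True where a query may attend to a cache slot."""
    chunked_pos = position_ids % chunk_size   # query position within its chunk
    kv_slots = torch.arange(target_length)    # slots of the chunk-sized KV cache
    # causal within the chunk: slots beyond the query's in-chunk position hold stale entries
    return kv_slots[None, None, :] <= chunked_pos[..., None]

mask = toy_chunked_causal_mask(torch.arange(10).unsqueeze(0), target_length=4, chunk_size=4)
print(mask.shape)  # torch.Size([1, 10, 4])
```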
@@ -903,6 +910,12 @@ def get_specializations(

         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        chunk_ctx_len = (
+            self.config.text_config.attention_chunk_size
+            if hasattr(self, "config")
+            else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+        )
+
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
@@ -926,6 +939,7 @@
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
             {
                 "batch_size": batch_size,
@@ -934,6 +948,7 @@
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
         ]

@@ -956,8 +971,14 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         vision_dynamic_axes["pixel_values"] = {0: "batch_size_times_num_tiles", 2: "img_size", 3: "img_size"}
         vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}

-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "batch_size"}
         for i in range(self.language_model.config.num_hidden_layers):
+            # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
+            if int((i + 1) % 4 != 0):
+                pkv_dynamic_axes[2] = "chunk_ctx_len"
+            else:
+                pkv_dynamic_axes[2] = "ctx_len"
+
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
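
For reference, a small sketch of the per-layer pattern this loop encodes, assuming the interleaved Llama4 layout in which every fourth decoder layer is a global (NoPE) layer and the others use RoPE with chunked attention. The layer count below is an illustrative toy value, and a fresh dict is built per layer so each cache entry keeps its own axis mapping.

```python
# Illustrative only: which cache-length axis each layer would get under the
# `(i + 1) % 4 != 0` rule above (chunked RoPE layers vs. every-fourth NoPE layer).
num_hidden_layers = 8  # toy value; the real count comes from the model config

lang_dynamic_axes = {}
for i in range(num_hidden_layers):
    axis = "ctx_len" if (i + 1) % 4 == 0 else "chunk_ctx_len"
    for kv in ["key", "value"]:
        # build a fresh dict per layer so entries do not share a single mutable mapping
        lang_dynamic_axes[f"past_{kv}.{i}"] = {0: "batch_size", 2: axis}

print({k: v[2] for k, v in lang_dynamic_axes.items() if k.startswith("past_key")})
# {'past_key.0': 'chunk_ctx_len', 'past_key.1': 'chunk_ctx_len', 'past_key.2': 'chunk_ctx_len',
#  'past_key.3': 'ctx_len', ...}
```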
3 changes: 3 additions & 0 deletions QEfficient/utils/constants.py
@@ -86,6 +86,9 @@ def get_models_dir():
 GRANITEVISION_CTX_LEN = 6000
 GRANITEVISION_NUM_CHANNELS = 3

+# Llama4 Constants
+LLAMA4_ATTENTION_CHUNK_SIZE = 8192
+

 class Constants:
     # Export Constants.