From ab6e66aaf328784e58d1fc0faf1dd92f6b66488e Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Thu, 8 May 2025 08:12:53 +0000
Subject: [PATCH 1/4] Added support for chunked attention

Signed-off-by: Rishin
---
 .../transformers/models/llama4/modeling_llama4.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index fde1da358..d85e49de9 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -566,8 +566,13 @@ def forward(
         key_states = key_states.transpose(1, 2)
 
         if past_key_value is not None:
+            chunk_position_ids = position_ids
+
+            if self.use_rope:
+                chunk_position_ids = chunk_position_ids % self.config.attention_chunk_size
+
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
@@ -714,9 +719,8 @@ def forward(
 
         causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
 
-        _, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        chunked_position_ids = position_ids % self.config.attention_chunk_size
+        chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=past_seen_tokens)
 
         # embed positions
         hidden_states = inputs_embeds

From d1665215a34fe3fa4515989c3896394af4d5e596 Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Thu, 8 May 2025 09:15:53 +0000
Subject: [PATCH 2/4] Added pkv input for switching between ctx len and chunked ctx len

Signed-off-by: Rishin
---
 .../models/llama4/modeling_llama4.py | 16 +++++++++++++++-
 QEfficient/utils/constants.py        |  3 +++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index d85e49de9..05973ccff 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -907,6 +907,12 @@ def get_specializations(
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
 
+        chunk_ctx_len = (
+            self.config.text_config.attention_chunk_size
+            if hasattr(self, "config")
+            else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+        )
+
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
@@ -930,6 +936,7 @@ def get_specializations(
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
             {
                 "batch_size": batch_size,
@@ -938,6 +945,7 @@ def get_specializations(
                 "batch_size_times_num_tiles": batch_size_times_num_tiles,
                 "img_size": img_size,
                 "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
         ]
 
@@ -960,8 +968,14 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         vision_dynamic_axes["pixel_values"] = {0: "batch_size_times_num_tiles", 2: "img_size", 3: "img_size"}
         vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
 
-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "batch_size"}
         for i in range(self.language_model.config.num_hidden_layers):
+            # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
+            if int((i + 1) % 4 != 0):
+                pkv_dynamic_axes[2] = "chunk_ctx_len"
+            else:
+                pkv_dynamic_axes[2] = "ctx_len"
+
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
 
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index b1ff9701e..e7e5e88df 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -86,6 +86,9 @@ def get_models_dir():
 GRANITEVISION_CTX_LEN = 6000
 GRANITEVISION_NUM_CHANNELS = 3
 
+# Llama4 Constants
+LLAMA4_ATTENTION_CHUNK_SIZE = 8192
+
 
 class Constants:
     # Export Constants.

From 0e03e4c7e883dcf439d8195599c1d6409334949c Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Tue, 13 May 2025 08:07:38 +0000
Subject: [PATCH 3/4] Fix for chunk causal mask and target length

Signed-off-by: Rishin
---
 QEfficient/transformers/models/llama4/modeling_llama4.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 05973ccff..5025953ca 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -718,9 +718,12 @@ def forward(
         position_ids = cache_position.unsqueeze(0)
 
         causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
+        chunk_causal_mask = None
 
-        chunked_position_ids = position_ids % self.config.attention_chunk_size
-        chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=past_seen_tokens)
+        if past_seen_tokens > self.config.attention_chunk_size:
+            chunked_position_ids = position_ids % self.config.attention_chunk_size
+            target_length = min(past_seen_tokens, torch.tensor(self.config.attention_chunk_size))
+            chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)
 
         # embed positions
         hidden_states = inputs_embeds

From 2b25e83c3ef6fe449513763ce9bc51f914c120c5 Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Wed, 14 May 2025 07:46:20 +0000
Subject: [PATCH 4/4] chunk ctx len fix

Signed-off-by: Rishin
---
 .../transformers/models/llama4/modeling_llama4.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 5025953ca..ab656f9db 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -910,10 +910,13 @@ def get_specializations(
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
 
-        chunk_ctx_len = (
-            self.config.text_config.attention_chunk_size
-            if hasattr(self, "config")
-            else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+        chunk_ctx_len = min(
+            ctx_len,
+            (
+                self.config.text_config.attention_chunk_size
+                if hasattr(self, "config")
+                else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+            ),
        )
 
         if img_size is None and hasattr(self.config.vision_config, "image_size"):