 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

@@ -83,6 +83,8 @@ class TrtllmAttentionWrapper:
     spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor]
     spec_decoding_bl_tree_mask: Optional[torch.Tensor]
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor]
+    helix_position_offsets: Optional[torch.Tensor]
+    helix_is_inactive_rank: Optional[torch.Tensor]
     kwargs: dict

     def __init__(
@@ -298,10 +300,6 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None and not isinstance(
-                self.helix_is_inactive_rank, torch.Tensor):
-            self.helix_is_inactive_rank = torch.tensor(
-                self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)

         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length
@@ -646,6 +644,14 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None

+    # Flag to enable helix parallelism.
+    enable_helix: bool = False
+
+    # Global position ids of the tokens for each sequence in the batch. Since
+    # each helix rank holds only a subset of a sequence's tokens, we compute a
+    # global position id for each token here.
+    helix_position_offsets: Optional[torch.Tensor] = None
+
     # Whether the current rank is inactive for helix parallelism.
     # In helix parallelism, only the active rank appends KV cache for the query token
     # and attends to the previously cached tokens as well as the query token. Inactive
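The new metadata fields above describe `helix_position_offsets` as global position ids assembled from each rank's local slice of a sequence. A rough, standalone illustration of that idea follows; the contiguous-slice layout and the helper name are assumptions made for the sketch, not code from this PR.

```python
# Illustration only: map a rank's local token indices to global position ids,
# assuming each helix rank holds one contiguous slice of the sequence.
import torch

def global_position_offsets(tokens_per_rank, rank):
    # Tokens held by earlier ranks come first in the global ordering.
    start = sum(tokens_per_rank[:rank])
    count = tokens_per_rank[rank]
    return torch.arange(start, start + count, dtype=torch.int)

# Example: a 10-token sequence split 4/3/3 across three helix ranks.
print(global_position_offsets([4, 3, 3], rank=1))  # tensor([4, 5, 6], dtype=torch.int32)
```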
@@ -824,11 +830,64 @@ def _post_init_with_buffers(self, buffers) -> None:
             pin_memory=True,
         )

+        # Allocate static buffers for helix parallelism support
+        if self.enable_helix:
+            self.helix_position_offsets = self.get_empty(
+                buffers,
+                (self.max_num_tokens, ),
+                cache_name="helix_position_offsets",
+                dtype=torch.int,
+                capture_graph=capture_graph,
+            )
+            self.helix_position_offsets_cpu = torch.empty_like(
+                self.helix_position_offsets,
+                device='cpu',
+                pin_memory=True,
+            )
+            self.helix_is_inactive_rank = self.get_empty(
+                buffers,
+                (self.max_num_sequences, ),
+                cache_name="helix_is_inactive_rank",
+                dtype=torch.bool,
+                capture_graph=capture_graph,
+            )
+            self.helix_is_inactive_rank_cpu = torch.empty_like(
+                self.helix_is_inactive_rank,
+                device='cpu',
+                pin_memory=True,
+            )
+
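The pairing above, a persistent device buffer obtained through `get_empty` plus a pinned CPU mirror, is the usual way to keep tensor addresses stable across CUDA graph replays: the captured graph keeps pointing at the same device storage while new values are staged through the pinned host copy. A minimal, generic sketch of that pattern, with illustrative names and sizes that are not taken from this PR:

```python
# Sketch of the pinned-host -> static-device staging pattern (names and sizes
# are illustrative). Requires a CUDA device.
import torch

MAX_TOKENS = 8192

# Allocated once; a captured CUDA graph can keep referencing this address.
device_buf = torch.empty(MAX_TOKENS, dtype=torch.int, device="cuda")
# Pinned host staging buffer so the host-to-device copy can be non-blocking.
host_buf = torch.empty(MAX_TOKENS, dtype=torch.int, pin_memory=True)

def refresh(values):
    # Copy new values in place instead of reallocating the device buffer.
    n = len(values)
    host_buf[:n].copy_(torch.tensor(values, dtype=torch.int))
    device_buf[:n].copy_(host_buf[:n], non_blocking=True)
```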
     def on_update_kv_lens(self):
         # After changing the kv_lens/kv_lens_cuda, we may need to update other metadata.
         # Especially for the changes in the _preprocess_inputs() of model_engine.py.
         pass

+    def update_helix_param(
+        self,
+        helix_position_offsets: List[int],
+        helix_is_inactive_rank: List[bool],
+    ) -> None:
+        """
+        Update helix parameters by copying them into static buffers for CUDA graph compatibility.
+
+        Args:
+            helix_position_offsets: Position offsets for helix parallelism with shape (num_tokens,).
+            helix_is_inactive_rank: Whether the current rank is inactive for each sequence, with shape (batch_size,).
+        """
+        if helix_position_offsets is not None and self.helix_position_offsets is not None:
+            num_tokens = len(helix_position_offsets)
+            self.helix_position_offsets_cpu[:num_tokens].copy_(
+                torch.tensor(helix_position_offsets, dtype=torch.int))
+            self.helix_position_offsets[:num_tokens].copy_(
+                self.helix_position_offsets_cpu[:num_tokens], non_blocking=True)
+
+        if helix_is_inactive_rank is not None and self.helix_is_inactive_rank is not None:
+            batch_size = len(helix_is_inactive_rank)
+            self.helix_is_inactive_rank_cpu[:batch_size].copy_(
+                torch.tensor(helix_is_inactive_rank, dtype=torch.bool))
+            self.helix_is_inactive_rank[:batch_size].copy_(
+                self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True)
+
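A hypothetical usage sketch for the method above, assuming `metadata` is a `TrtllmAttentionMetadata` constructed with `enable_helix=True` so the static buffers exist; the values and the call site are illustrative, not taken from this PR:

```python
# Hypothetical call site; the surrounding setup is assumed.
metadata.update_helix_param(
    helix_position_offsets=[128, 129, 130, 131],  # one global position per token
    helix_is_inactive_rank=[False, True],         # one flag per sequence in the batch
)
```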
     def prepare(self) -> None:
         extra_attrs = get_model_extra_attrs()
         # If model extra attrs is set, attention_metadata is setup in executor.
@@ -868,18 +927,13 @@ def prepare(self) -> None:

         if self.enable_flash_mla:
             self.prepare_flash_mla()
-        # number of tokens needed in the kv cache for each sequence after the next pass
-        if self.helix_is_inactive_rank is not None and len(
-                self.helix_is_inactive_rank):
+
+        # number of tokens needed in the kv cache for each sequence after the next pass.
+        if self.enable_helix:
             # If helix is inactive, attend to the previously cached tokens only.
             assert cached_token_lens is not None, "cached_token_lens should be set for helix"
+            active_rank = ~self.helix_is_inactive_rank_cpu[:self.num_seqs]
             kv_lens = cached_token_lens.clone()
-            helix_is_inactive_rank_cpu = torch.tensor(
-                self.helix_is_inactive_rank,
-                dtype=torch.bool,
-                device='cpu',
-            )
-            active_rank = ~helix_is_inactive_rank_cpu
             kv_lens[active_rank] += self.seq_lens_kv[active_rank]
         else:
             kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
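The masked update above can be rehearsed in isolation: active ranks grow their KV length by the new query tokens, while inactive ranks keep only their cached length. A toy, CPU-only sketch with made-up numbers:

```python
# Toy rehearsal of the kv_lens update; the lengths are illustrative.
import torch

cached_token_lens = torch.tensor([10, 7, 3])     # tokens already in the KV cache
seq_lens_kv = torch.tensor([1, 1, 1])            # new query tokens this step
helix_is_inactive_rank_cpu = torch.tensor([False, True, False])

kv_lens = cached_token_lens.clone()
active_rank = ~helix_is_inactive_rank_cpu
kv_lens[active_rank] += seq_lens_kv[active_rank] # only active ranks append

print(kv_lens)  # prints tensor([11,  7,  4]); the inactive rank keeps its cached length
```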
@@ -1485,7 +1539,6 @@ def forward(
         mrope_config: Optional[dict] = None,
         attention_window_size: Optional[int] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
-        helix_position_offsets: Optional[torch.Tensor] = None,
         enable_attn_nvfp4_output: bool = True,
         output: Optional[torch.Tensor] = None,
         output_sf: Optional[torch.Tensor] = None,
@@ -1596,7 +1649,7 @@ def forward(
             sparse_attn_indices_block_size=sparse_attn_indices_block_size,
             sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
                 metadata, 'sparse_mla_topk') else 0,
-            helix_position_offsets=helix_position_offsets,
+            helix_position_offsets=metadata.helix_position_offsets,
             helix_is_inactive_rank=metadata.helix_is_inactive_rank,
         )
         out_dtype = None
@@ -1856,8 +1909,6 @@ def mla_rope_generation(
         mla_bmm1_scale: torch.Tensor,
         mla_bmm2_scale: torch.Tensor,
         quant_q_buffer: torch.Tensor,
-        helix_position_offsets: Optional[torch.Tensor] = None,
-        helix_is_inactive_rank: Optional[torch.Tensor] = None,
         out_scale: Optional[torch.Tensor] = None,
     ) -> None:
         """
@@ -1878,13 +1929,9 @@ def mla_rope_generation(
         assert metadata.kv_cache_manager is not None
         sink_token_length = 0

-        # Ensure helix_is_inactive_rank and position_ids are on the same device.
-        if helix_is_inactive_rank is not None:
-            assert helix_is_inactive_rank.device == helix_position_offsets.device, \
-                f"helix_is_inactive_rank must be on the same device as helix_position_offsets, " \
-                f"got {helix_is_inactive_rank.device} vs {helix_position_offsets.device}"
-
-        mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]
+        mla_tensor_params = [
+            metadata.helix_position_offsets, metadata.helix_is_inactive_rank
+        ]

         torch.ops.trtllm.mla_rope_generation(
             fused_q,