 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

@@ -83,6 +83,8 @@ class TrtllmAttentionWrapper:
     spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor]
     spec_decoding_bl_tree_mask: Optional[torch.Tensor]
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor]
+    helix_position_offsets: Optional[torch.Tensor]
+    helix_is_inactive_rank: Optional[torch.Tensor]
     kwargs: dict

     def __init__(
@@ -298,10 +300,6 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None and not isinstance(
-                self.helix_is_inactive_rank, torch.Tensor):
-            self.helix_is_inactive_rank = torch.tensor(
-                self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)

         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length
@@ -646,12 +644,22 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None

+    # Flag to enable helix parallelism.
+    enable_helix: bool = False
+
+    # Global position ids of tokens for each sequence in the batch. Given that
+    # each helix rank holds only a subset of a sequence's tokens, we compute a
+    # global position id for each token here.
+    helix_position_offsets: Optional[torch.Tensor] = None
+    helix_position_offsets_cpu: Optional[torch.Tensor] = None
+
     # Whether the current rank is inactive for helix parallelism.
     # In helix parallelism, only the active rank appends KV cache for the query token
     # and attends to the previously cached tokens as well as the query token. Inactive
     # ranks do not append KV cache for the query token and attend to the previously
     # cached tokens only.
     helix_is_inactive_rank: Optional[torch.Tensor] = None
+    helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None

     @property
     def max_seq_len(self) -> int:
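To make the semantics of the two new field pairs concrete, here is a small illustrative sketch (the two-rank split, the shard sizes, and the active-rank choice are made-up assumptions, not values taken from this PR): with helix parallelism the KV cache of a sequence is sharded across ranks, so a rank-local index is not a valid RoPE position. helix_position_offsets carries each token's position in the full sequence, and helix_is_inactive_rank marks, per sequence, whether this rank skips appending the query token to its KV shard.

import torch

# Hypothetical single-sequence decode step split across two helix ranks.
# Rank 0 caches tokens [0..511], rank 1 caches tokens [512..1023]; the shard
# sizes below are illustrative only.
tokens_cached_per_rank = [512, 512]
global_past_len = sum(tokens_cached_per_rank)  # 1024 tokens cached overall

# Every rank feeds the same *global* position of the query token to RoPE,
# regardless of how few tokens its local KV shard holds.
helix_position_offsets = torch.tensor([global_past_len], dtype=torch.int)

# Per sequence, only the active rank appends the query token to its KV shard;
# here we assume this rank is the inactive one for sequence 0.
helix_is_inactive_rank = torch.tensor([True])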
@@ -696,6 +704,8 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:

     def __post_init__(self) -> None:
         super().__post_init__()
+        self.enable_helix = self.mapping.has_cp_helix(
+        ) if self.mapping is not None else False
         self._post_init_with_buffers(self.cuda_graph_buffers)

     def _post_init_with_buffers(self, buffers) -> None:
@@ -824,11 +834,64 @@ def _post_init_with_buffers(self, buffers) -> None:
             pin_memory=True,
         )

+        # Allocate static buffers for helix parallelism support
+        if self.enable_helix:
+            self.helix_position_offsets = self.get_empty(
+                buffers,
+                (self.max_num_tokens, ),
+                cache_name="helix_position_offsets",
+                dtype=torch.int,
+                capture_graph=capture_graph,
+            )
+            self.helix_position_offsets_cpu = torch.empty_like(
+                self.helix_position_offsets,
+                device='cpu',
+                pin_memory=True,
+            )
+            self.helix_is_inactive_rank = self.get_empty(
+                buffers,
+                (self.max_num_sequences, ),
+                cache_name="helix_is_inactive_rank",
+                dtype=torch.bool,
+                capture_graph=capture_graph,
+            )
+            self.helix_is_inactive_rank_cpu = torch.empty_like(
+                self.helix_is_inactive_rank,
+                device='cpu',
+                pin_memory=True,
+            )
+
     def on_update_kv_lens(self):
         # After changing the kv_lens/kv_lens_cuda, we may need to update other metadata.
         # Especially for the changes in the _preprocess_inputs() of model_engine.py.
         pass

+    def update_helix_param(
+        self,
+        helix_position_offsets: List[int],
+        helix_is_inactive_rank: List[bool],
+    ) -> None:
+        """
+        Update helix parameters by copying into static buffers for CUDA graph compatibility.
+
+        Args:
+            helix_position_offsets: Position offsets for helix parallelism with shape (num_tokens,).
+            helix_is_inactive_rank: Whether the current rank is inactive with shape (batch_size,).
+        """
+        if helix_position_offsets is not None and self.helix_position_offsets is not None:
+            num_tokens = len(helix_position_offsets)
+            self.helix_position_offsets_cpu[:num_tokens].copy_(
+                torch.tensor(helix_position_offsets, dtype=torch.int))
+            self.helix_position_offsets[:num_tokens].copy_(
+                self.helix_position_offsets_cpu[:num_tokens], non_blocking=True)
+
+        if helix_is_inactive_rank is not None and self.helix_is_inactive_rank is not None:
+            batch_size = len(helix_is_inactive_rank)
+            self.helix_is_inactive_rank_cpu[:batch_size].copy_(
+                torch.tensor(helix_is_inactive_rank, dtype=torch.bool))
+            self.helix_is_inactive_rank[:batch_size].copy_(
+                self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True)
+
     def prepare(self) -> None:
         extra_attrs = get_model_extra_attrs()
         # If model extra attrs is set, attention_metadata is setup in executor.
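update_helix_param is the CUDA-graph-friendly entry point for the buffers allocated above: the Python lists are staged into the pinned-CPU mirrors and then copied asynchronously into device buffers whose addresses never change, so a captured graph keeps reading fresh values on replay. The snippet below is a standalone sketch of that staging pattern under assumed names (device_buf, host_buf, and update are illustrative, not part of the codebase); the real call goes through metadata.update_helix_param.

import torch

# Illustrative stand-ins for helix_position_offsets / helix_position_offsets_cpu.
max_num_tokens = 8
device_buf = torch.zeros(max_num_tokens, dtype=torch.int, device="cuda")
host_buf = torch.empty(max_num_tokens, dtype=torch.int, pin_memory=True)

def update(values):
    # Stage on pinned host memory, then issue an async H2D copy. The device
    # tensor keeps the same storage, which is what CUDA graph replay needs.
    n = len(values)
    host_buf[:n].copy_(torch.tensor(values, dtype=torch.int))
    device_buf[:n].copy_(host_buf[:n], non_blocking=True)

update([1024, 2051])  # e.g. one global position offset per new token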
@@ -868,18 +931,13 @@ def prepare(self) -> None:

         if self.enable_flash_mla:
             self.prepare_flash_mla()
-        # number of tokens needed in the kv cache for each sequence after the next pass
-        if self.helix_is_inactive_rank is not None and len(
-                self.helix_is_inactive_rank):
+
+        # number of tokens needed in the kv cache for each sequence after the next pass.
+        if self.enable_helix:
             # If helix is inactive, attend to the previously cached tokens only.
             assert cached_token_lens is not None, "cached_token_lens should be set for helix"
+            active_rank = ~self.helix_is_inactive_rank_cpu[:self.num_seqs]
             kv_lens = cached_token_lens.clone()
-            helix_is_inactive_rank_cpu = torch.tensor(
-                self.helix_is_inactive_rank,
-                dtype=torch.bool,
-                device='cpu',
-            )
-            active_rank = ~helix_is_inactive_rank_cpu
             kv_lens[active_rank] += self.seq_lens_kv[active_rank]
         else:
             kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
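A worked example of the branch above, with made-up lengths: two sequences on this rank have 10 and 7 tokens cached and one new query token each; only the sequence for which this rank is active reserves KV space for the query token.

import torch

# Illustrative values only; mirrors the active-rank branch in prepare().
cached_token_lens = torch.tensor([10, 7])
seq_lens_kv = torch.tensor([1, 1])
helix_is_inactive_rank_cpu = torch.tensor([False, True])

active_rank = ~helix_is_inactive_rank_cpu
kv_lens = cached_token_lens.clone()
kv_lens[active_rank] += seq_lens_kv[active_rank]
# kv_lens == tensor([11, 7]): the inactive rank attends to cached tokens only
# and does not grow its KV cache for the query token.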
@@ -1485,7 +1543,6 @@ def forward(
         mrope_config: Optional[dict] = None,
         attention_window_size: Optional[int] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
-        helix_position_offsets: Optional[torch.Tensor] = None,
         enable_attn_nvfp4_output: bool = True,
         output: Optional[torch.Tensor] = None,
         output_sf: Optional[torch.Tensor] = None,
@@ -1596,7 +1653,7 @@ def forward(
             sparse_attn_indices_block_size=sparse_attn_indices_block_size,
             sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
                 metadata, 'sparse_mla_topk') else 0,
-            helix_position_offsets=helix_position_offsets,
+            helix_position_offsets=metadata.helix_position_offsets,
             helix_is_inactive_rank=metadata.helix_is_inactive_rank,
         )
         out_dtype = None
@@ -1856,8 +1913,6 @@ def mla_rope_generation(
         mla_bmm1_scale: torch.Tensor,
         mla_bmm2_scale: torch.Tensor,
         quant_q_buffer: torch.Tensor,
-        helix_position_offsets: Optional[torch.Tensor] = None,
-        helix_is_inactive_rank: Optional[torch.Tensor] = None,
         out_scale: Optional[torch.Tensor] = None,
     ) -> None:
         """
@@ -1878,13 +1933,9 @@ def mla_rope_generation(
         assert metadata.kv_cache_manager is not None
         sink_token_length = 0

-        # Ensure helix_is_inactive_rank and position_ids are on the same device.
-        if helix_is_inactive_rank is not None:
-            assert helix_is_inactive_rank.device == helix_position_offsets.device, \
-                f"helix_is_inactive_rank must be on the same device as helix_position_offsets, " \
-                f"got {helix_is_inactive_rank.device} vs {helix_position_offsets.device}"
-
-        mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]
+        mla_tensor_params = [
+            metadata.helix_position_offsets, metadata.helix_is_inactive_rank
+        ]

         torch.ops.trtllm.mla_rope_generation(
             fused_q,