Commit 1e7acc6

[None][feat] Cudagraph updates for helix parallelism

Signed-off-by: Balaram Buddharaju <[email protected]>

1 parent 799a2ae commit 1e7acc6

9 files changed: +167 -72 lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 76 additions & 25 deletions
@@ -2,7 +2,7 @@
 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch

@@ -83,6 +83,8 @@ class TrtllmAttentionWrapper:
     spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor]
     spec_decoding_bl_tree_mask: Optional[torch.Tensor]
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor]
+    helix_position_offsets: Optional[torch.Tensor]
+    helix_is_inactive_rank: Optional[torch.Tensor]
     kwargs: dict

     def __init__(
@@ -298,10 +300,6 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None and not isinstance(
-                self.helix_is_inactive_rank, torch.Tensor):
-            self.helix_is_inactive_rank = torch.tensor(
-                self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)

         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length
@@ -646,12 +644,22 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None

+    # Flag to enable helix parallelism.
+    enable_helix: bool = False
+
+    # Global position ids of tokens for each sequence in the batch. Given
+    # each helix rank has only a subset of tokens for a sequence, we compute
+    # a global position id for each token here.
+    helix_position_offsets: Optional[torch.Tensor] = None
+    helix_position_offsets_cpu: Optional[torch.Tensor] = None
+
     # Whether the current rank is inactive for helix parallelism.
     # In helix parallelism, only the active rank appends KV cache for the query token
     # and attends to the previously cached tokens as well as the query token. Inactive
     # ranks do not append KV cache for the query token and attend to the previously
     # cached tokens only.
     helix_is_inactive_rank: Optional[torch.Tensor] = None
+    helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None

     @property
     def max_seq_len(self) -> int:
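
Editorial note: the comment above motivates a per-token global position id when a sequence is sharded across helix (context-parallel) ranks. The standalone sketch below, which is not repository code and assumes a contiguous split of the sequence across ranks (the real partitioning is decided elsewhere in the codebase), shows the idea; `tokens_per_rank` and `global_position_offsets` are illustrative names only.

import torch

def global_position_offsets(tokens_per_rank, rank):
    # Tokens held by lower ranks precede this rank's tokens in the global order.
    start = sum(tokens_per_rank[:rank])
    return torch.arange(start, start + tokens_per_rank[rank], dtype=torch.int)

# Example: a 12-token sequence split 4/4/4 across three helix ranks.
print(global_position_offsets([4, 4, 4], rank=1))  # tensor([4, 5, 6, 7], dtype=torch.int32)
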
@@ -696,6 +704,8 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:

     def __post_init__(self) -> None:
         super().__post_init__()
+        self.enable_helix = self.mapping.has_cp_helix(
+        ) if self.mapping is not None else False
         self._post_init_with_buffers(self.cuda_graph_buffers)

     def _post_init_with_buffers(self, buffers) -> None:
@@ -824,11 +834,64 @@ def _post_init_with_buffers(self, buffers) -> None:
             pin_memory=True,
         )

+        # Allocate static buffers for helix parallelism support
+        if self.enable_helix:
+            self.helix_position_offsets = self.get_empty(
+                buffers,
+                (self.max_num_tokens, ),
+                cache_name="helix_position_offsets",
+                dtype=torch.int,
+                capture_graph=capture_graph,
+            )
+            self.helix_position_offsets_cpu = torch.empty_like(
+                self.helix_position_offsets,
+                device='cpu',
+                pin_memory=True,
+            )
+            self.helix_is_inactive_rank = self.get_empty(
+                buffers,
+                (self.max_num_sequences, ),
+                cache_name="helix_is_inactive_rank",
+                dtype=torch.bool,
+                capture_graph=capture_graph,
+            )
+            self.helix_is_inactive_rank_cpu = torch.empty_like(
+                self.helix_is_inactive_rank,
+                device='cpu',
+                pin_memory=True,
+            )
+
     def on_update_kv_lens(self):
         # After changing the kv_lens/kv_lens_cuda, we may need to update other metadata.
         # Especially for the changes in the _preprocess_inputs() of model_engine.py.
         pass

+    def update_helix_param(
+        self,
+        helix_position_offsets: List[int],
+        helix_is_inactive_rank: List[bool],
+    ) -> None:
+        """
+        Update helix parameters by copying into static buffers for CUDA graph compatibility.
+
+        Args:
+            helix_position_offsets: Position offsets for helix parallelism with shape (num_tokens,).
+            helix_is_inactive_rank: Whether the current rank is inactive with shape (batch_size,).
+        """
+        if helix_position_offsets is not None and self.helix_position_offsets is not None:
+            num_tokens = len(helix_position_offsets)
+            self.helix_position_offsets_cpu[:num_tokens].copy_(
+                torch.tensor(helix_position_offsets, dtype=torch.int))
+            self.helix_position_offsets[:num_tokens].copy_(
+                self.helix_position_offsets_cpu[:num_tokens], non_blocking=True)
+
+        if helix_is_inactive_rank is not None and self.helix_is_inactive_rank is not None:
+            batch_size = len(helix_is_inactive_rank)
+            self.helix_is_inactive_rank_cpu[:batch_size].copy_(
+                torch.tensor(helix_is_inactive_rank, dtype=torch.bool))
+            self.helix_is_inactive_rank[:batch_size].copy_(
+                self.helix_is_inactive_rank_cpu[:batch_size], non_blocking=True)
+
     def prepare(self) -> None:
         extra_attrs = get_model_extra_attrs()
         # If model extra attrs is set, attention_metadata is setup in executor.
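
Editorial note: to make the intent of the pinned-CPU staging plus fixed-address GPU buffer pattern above concrete, here is a minimal, hypothetical sketch (not the repository's code) of how per-step inputs can be refreshed in place so a captured CUDA graph keeps reading the same device addresses. `static_gpu`, `staging_cpu`, and `update_param` are made-up names that mirror the roles of `helix_position_offsets`, `helix_position_offsets_cpu`, and `update_helix_param`.

import torch

MAX_TOKENS = 8
static_gpu = torch.zeros(MAX_TOKENS, dtype=torch.int, device="cuda")
staging_cpu = torch.empty(MAX_TOKENS, dtype=torch.int, pin_memory=True)

def update_param(values):
    # Host list -> pinned CPU buffer -> async H2D copy into the static buffer.
    n = len(values)
    staging_cpu[:n].copy_(torch.tensor(values, dtype=torch.int))
    static_gpu[:n].copy_(staging_cpu[:n], non_blocking=True)

out = torch.empty_like(static_gpu)
update_param(list(range(MAX_TOKENS)))
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # The captured kernel reads `static_gpu` by address, not by value.
    out.copy_(static_gpu * 2)

update_param(list(reversed(range(MAX_TOKENS))))  # new step's values, same buffer
torch.cuda.synchronize()
graph.replay()
print(out)  # doubled values of the refreshed buffer
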
@@ -868,18 +931,13 @@ def prepare(self) -> None:

         if self.enable_flash_mla:
             self.prepare_flash_mla()
-        # number of tokens needed in the kv cache for each sequence after the next pass
-        if self.helix_is_inactive_rank is not None and len(
-                self.helix_is_inactive_rank):
+
+        # number of tokens needed in the kv cache for each sequence after the next pass.
+        if self.enable_helix:
             # If helix is inactive, attend to the previously cached tokens only.
             assert cached_token_lens is not None, "cached_token_lens should be set for helix"
+            active_rank = ~self.helix_is_inactive_rank_cpu[:self.num_seqs]
             kv_lens = cached_token_lens.clone()
-            helix_is_inactive_rank_cpu = torch.tensor(
-                self.helix_is_inactive_rank,
-                dtype=torch.bool,
-                device='cpu',
-            )
-            active_rank = ~helix_is_inactive_rank_cpu
             kv_lens[active_rank] += self.seq_lens_kv[active_rank]
         else:
             kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
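
Editorial note: a toy illustration of the masked update above (not repository code): a sequence whose rank is active for helix needs room for the incoming query token on the next pass, while a sequence whose rank is inactive keeps its cache-only length.

import torch

cached_token_lens = torch.tensor([10, 20, 30])      # tokens already in the KV cache
seq_lens_kv = torch.tensor([1, 1, 1])               # one new query token per sequence
is_inactive = torch.tensor([False, True, False])    # plays the role of helix_is_inactive_rank_cpu

active_rank = ~is_inactive
kv_lens = cached_token_lens.clone()
kv_lens[active_rank] += seq_lens_kv[active_rank]
print(kv_lens)  # tensor([11, 20, 31]): the inactive sequence attends to cached tokens only
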
@@ -1485,7 +1543,6 @@ def forward(
         mrope_config: Optional[dict] = None,
         attention_window_size: Optional[int] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
-        helix_position_offsets: Optional[torch.Tensor] = None,
         enable_attn_nvfp4_output: bool = True,
         output: Optional[torch.Tensor] = None,
         output_sf: Optional[torch.Tensor] = None,
@@ -1596,7 +1653,7 @@ def forward(
             sparse_attn_indices_block_size=sparse_attn_indices_block_size,
             sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
                 metadata, 'sparse_mla_topk') else 0,
-            helix_position_offsets=helix_position_offsets,
+            helix_position_offsets=metadata.helix_position_offsets,
             helix_is_inactive_rank=metadata.helix_is_inactive_rank,
         )
         out_dtype = None
@@ -1856,8 +1913,6 @@ def mla_rope_generation(
         mla_bmm1_scale: torch.Tensor,
         mla_bmm2_scale: torch.Tensor,
         quant_q_buffer: torch.Tensor,
-        helix_position_offsets: Optional[torch.Tensor] = None,
-        helix_is_inactive_rank: Optional[torch.Tensor] = None,
         out_scale: Optional[torch.Tensor] = None,
     ) -> None:
         """
@@ -1878,13 +1933,9 @@ def mla_rope_generation(
         assert metadata.kv_cache_manager is not None
         sink_token_length = 0

-        # Ensure helix_is_inactive_rank and position_ids are on the same device.
-        if helix_is_inactive_rank is not None:
-            assert helix_is_inactive_rank.device == helix_position_offsets.device, \
-                f"helix_is_inactive_rank must be on the same device as helix_position_offsets, " \
-                f"got {helix_is_inactive_rank.device} vs {helix_position_offsets.device}"
-
-        mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]
+        mla_tensor_params = [
+            metadata.helix_position_offsets, metadata.helix_is_inactive_rank
+        ]

         torch.ops.trtllm.mla_rope_generation(
             fused_q,

tensorrt_llm/_torch/modules/attention.py

Lines changed: 16 additions & 22 deletions
@@ -708,7 +708,7 @@ def __init__(
         dtype: torch.dtype = None,
         dense_bias: Optional[bool] = None,
         config: Optional[ModelConfig] = None,
-        enable_unit_test: bool = False,
+        enable_helix_test: bool = False,
         mapping_with_cp: Optional[Mapping] = None,
         reduce_output: bool = True,
     ):
@@ -733,7 +733,7 @@ def __init__(
             dtype (torch.dtype): The data type.
             dense_bias (bool): Whether to use bias in the output projection layer.
             config (ModelConfig): The model configuration.
-            enable_unit_test (bool): Whether to enable unit test.
+            enable_helix_test (bool): Whether to enable helix unit test.
         """
         super().__init__()
         self.layer_idx = layer_idx
@@ -754,7 +754,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.pos_embd_params = pos_embd_params
         self.dense_bias = dense_bias
-        self.enable_unit_test = enable_unit_test
+        self.enable_helix_test = enable_helix_test
         if dense_bias is None:
             self.dense_bias = bias

@@ -816,7 +816,7 @@ def __init__(
         self.num_key_value_heads_tp = (self.num_key_value_heads + tp_size -
                                        1) // tp_size

-        if self.enable_unit_test:
+        if self.enable_helix_test:
             rms_norm_eps = getattr(config.pretrained_config, "rms_norm_eps",
                                    1e-6)
         else:
@@ -1108,8 +1108,8 @@ def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,
             v,
             attn_metadata,
             softmax_stats_tensor=softmax_stats,
-            helix_position_offsets=position_ids,
-            **kwargs)
+            **kwargs,
+        )
         # this is the post-processing of helix parallel attention,
         # similar to the post-processing of ring attention
         kv_lora_rank = partial_o.shape[-1] // self.num_heads_tp
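
Editorial note: for readers unfamiliar with the "post-processing similar to ring attention" mentioned in the comment above, here is a standalone sketch (not the repository's implementation) of how partial attention outputs from different ranks can be merged using their softmax statistics. The shapes and the helper name `merge_partial_attention` are assumptions for illustration only.

import torch

def merge_partial_attention(partial_o, lse):
    # partial_o: [num_ranks, num_tokens, head_dim] outputs computed against each
    #            rank's slice of the KV cache.
    # lse:       [num_ranks, num_tokens] log-sum-exp of each rank's local softmax.
    # Weight each rank's contribution by its share of the global softmax mass,
    # using a max-shift for numerical stability.
    lse_max = lse.max(dim=0, keepdim=True).values
    weights = torch.exp(lse - lse_max)
    weights = weights / weights.sum(dim=0, keepdim=True)
    return (weights.unsqueeze(-1) * partial_o).sum(dim=0)
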
@@ -1135,7 +1135,7 @@ def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,
     def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         num_tokens = hidden_states.shape[0]
         hidden_size = self.o_proj.in_features
-        if self.enable_unit_test and num_contexts > 0:
+        if self.enable_helix_test and num_contexts > 0:
             # note: for testing Helix parallelism, we ensure that the output is
             # large enough for the context phase, but we then cut it again in
             # `forward_context`
@@ -1379,6 +1379,12 @@ def forward_context_default(
             -1,
         )

+        if self.enable_helix_test:
+            # While helix parallelism is mainly meant for generation, we set the
+            # helix position offsets for the context phase to get the math right
+            # in test_mla_helix.py.
+            attn_metadata.helix_position_offsets = position_ids
+
         k = torch.empty_like(q).view(-1, self.num_heads_tp, self.qk_head_dim)
         maybe_compiled_copy_(
             k[..., :self.qk_nope_head_dim],
@@ -1388,17 +1394,13 @@ def forward_context_default(
                               self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads_tp * self.qk_head_dim)

-        helix_position_offsets = position_ids if self.mapping.has_cp_helix(
-        ) else None
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            helix_position_offsets=helix_position_offsets,
             out_scale=self.out_scale,
             output=output,
         )
@@ -1769,12 +1771,6 @@ def forward_absorption_generation(
             device=q.device,
         )

-        helix_position_offsets, helix_is_inactive_rank = None, None
-        if self.mapping.has_cp_helix():
-            helix_position_offsets = position_ids
-            helix_is_inactive_rank = attn_metadata.helix_is_inactive_rank
-            assert helix_position_offsets is not None and helix_is_inactive_rank is not None, "helix_position_offsets and helix_is_inactive_rank must be provided for helix parallelism."
-
         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
             # [num_heads, num_tokens, self.qk_nope_head_dim]
@@ -1799,8 +1795,7 @@ def forward_absorption_generation(
                 mla_bmm1_scale,
                 mla_bmm2_scale,
                 quant_q_buffer,
-                helix_position_offsets=helix_position_offsets,
-                helix_is_inactive_rank=helix_is_inactive_rank),
+            ),
             self.ln_events[0],
             self.ln_events[1],
             rope_stream,
@@ -1829,8 +1824,7 @@ def forward_absorption_generation(
                 mla_bmm1_scale,
                 mla_bmm2_scale,
                 quant_q_buffer,
-                helix_position_offsets=helix_position_offsets,
-                helix_is_inactive_rank=helix_is_inactive_rank),
+            ),
             self.ln_events[0],
             self.ln_events[1],
             rope_stream,
@@ -2182,7 +2176,7 @@ def forward(
                 output=attn_output,
                 latent_cache_gen=latent_cache_gen)

-        if self.enable_unit_test and self.mapping.has_cp_helix():
+        if self.enable_helix_test and self.mapping.has_cp_helix():
             # note: for allowing testing Helix parallelism, we ensure that
             # the output is compatible with o_proj even in the context phase,
             # thus we cut it to num_heads_tp_cp * v_head_dim

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 2 deletions
@@ -488,9 +488,10 @@ def __init__(
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
         self.py_min_length = self.sampling_config.min_length
+        # `seqlen_this_rank_cp`, `total_input_len_cp`, and `py_helix_is_inactive_rank` are relevant to helix parallelism.
+        self.seqlen_this_rank_cp = self.prompt_len
+        self.total_input_len_cp = self.prompt_len
         self.py_helix_is_inactive_rank = False
-        self.seqlen_this_rank_cp = 0
-        self.total_input_len_cp = 0
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0