Skip to content

Commit 2933e93

Browse files
committed
move enable_helix
1 parent 718bdc8 commit 2933e93

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,8 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
702702

703703
def __post_init__(self) -> None:
704704
super().__post_init__()
705+
self.enable_helix = self.mapping.has_cp_helix(
706+
) if self.mapping is not None else False
705707
self._post_init_with_buffers(self.cuda_graph_buffers)
706708

707709
def _post_init_with_buffers(self, buffers) -> None:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,8 +1037,6 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
10371037
cache_indirection=cache_indirection,
10381038
sparse_attention_config=self.sparse_attention_config,
10391039
num_heads_per_kv=num_heads_per_kv,
1040-
enable_helix=self.mapping.has_cp_helix()
1041-
if self.mapping is not None else False,
10421040
)
10431041

10441042
return self.attn_metadata

0 commit comments

Comments (0)