
Commit 518086c

correct mamba cache dtype extraction
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent b7798bf commit 518086c

File tree

3 files changed: +10 -82 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ class CacheConfig:
     """A dataclass to hold information how to configure the cache."""
 
     dtype: Optional[torch.dtype] = None
+    mamba_dtype: Optional[torch.dtype] = None
 
 
 class SequenceInfo:
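
The new field lets the SSM (Mamba) state cache dtype be configured separately from the regular cache dtype. A minimal usage sketch, assuming CacheConfig is importable under the module path implied by the file above (the import path is inferred from the file path, not shown in the diff):

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

# Keep regular caches in fp16 while pinning the Mamba SSM state cache to fp32.
# Leaving mamba_dtype as None makes the backend fall back to a dtype derived
# from the graph node (see the torch backend change below).
cache_config = CacheConfig(dtype=torch.float16, mamba_dtype=torch.float32)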

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py

Lines changed: 5 additions & 2 deletions

@@ -325,7 +325,7 @@ def get_cached_attention_op(cls) -> MHACallable:
 
     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx)
+        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
         return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
 
     @classmethod
@@ -339,6 +339,9 @@ def get_cache_initializers(
         num_heads = hs_fake.shape[-2]
         head_dim = hs_fake.shape[-1]
 
+        # dtype from node itself
+        dtype = source_attn_node.meta["val"].dtype
+
         # Infer state size by assuming B has shape [b, s, n_groups * ssm_state_size]
         # During runtime we pass [b, s, n_groups, ssm_state_size]; both give the same last dim product.
         if B_fake.ndim >= 4:
@@ -354,7 +357,7 @@ def _get_ssm_cache(si: SequenceInfo):
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=cache_config.mamba_dtype or dtype,
             )
 
         return {"ssm_state_cache": _get_ssm_cache}
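
The SSM cache dtype is now resolved as the configured mamba_dtype first, falling back to the dtype recorded on the source attention node itself (source_attn_node.meta["val"].dtype) rather than the generic cache dtype or the fake hidden-states tensor. A standalone sketch of that precedence; resolve_ssm_cache_dtype is a hypothetical helper written only to illustrate the fallback and is not part of the diff:

from typing import Optional

import torch


def resolve_ssm_cache_dtype(
    mamba_dtype: Optional[torch.dtype], node_dtype: torch.dtype
) -> torch.dtype:
    # An explicitly configured mamba_dtype wins; otherwise fall back to the
    # dtype of the source attention node's output ("val" meta on the fx node).
    return mamba_dtype or node_dtype


assert resolve_ssm_cache_dtype(None, torch.bfloat16) is torch.bfloat16
assert resolve_ssm_cache_dtype(torch.float32, torch.bfloat16) is torch.float32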

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 4 additions & 80 deletions

@@ -1,27 +1,14 @@
-from typing import List, Tuple
+from typing import List
 
 import torch
-from torch._ops import OpOverloadPacket
-from torch.fx import Node
 
 # Triton kernels
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
 from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
 from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined
 
-from ...utils.node_utils import extract_op_args
-from ..attention_interface import (
-    AttentionDescriptor,
-    AttentionLayout,
-    AttentionRegistry,
-    BufferInitializerDict,
-    CacheConfig,
-    CacheInitializerDict,
-    Constant,
-    MHACallable,
-    PrepareMetadataCallable,
-    SequenceInfo,
-)
+from ..attention_interface import AttentionRegistry, MHACallable
+from .torch_backend_mamba import TorchBackendSSM
 
 
 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
@@ -202,70 +189,7 @@ def _triton_cached_ssm_fake(
 
 
 @AttentionRegistry.register("triton_ssm")
-class TritonBackendSSM(AttentionDescriptor):
-    @classmethod
-    def is_paged(cls) -> bool:
-        return True
-
-    @classmethod
-    def get_attention_layout(cls) -> AttentionLayout:
-        # Hidden states follow [b, s, n, d]
-        return "bsnd"
-
-    @classmethod
-    def get_num_qkv_args(cls) -> int:
-        # torch_ssm_transform signature has 7 node/state arguments
-        return 7
-
-    @classmethod
-    def get_source_attention_op(cls) -> OpOverloadPacket:
-        # Keep source op unchanged (used for uncached pre-export)
-        return torch.ops.auto_deploy.torch_ssm
-
+class TritonBackendSSM(TorchBackendSSM):
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
         return torch.ops.auto_deploy.triton_cached_ssm
-
-    @classmethod
-    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
-
-    @classmethod
-    def get_cache_initializers(
-        cls, source_attn_node: Node, cache_config: CacheConfig
-    ) -> CacheInitializerDict:
-        # Shapes from fake tensors
-        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
-        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
-
-        num_heads = hs_fake.shape[-2]
-        head_dim = hs_fake.shape[-1]
-
-        if B_fake.ndim >= 4:
-            ssm_state_size = B_fake.shape[-1]
-        else:
-            ssm_state_size = max(1, B_fake.shape[-1])
-
-        def _get_ssm_cache(si: SequenceInfo):
-            return torch.empty(
-                si.max_batch_size,
-                num_heads,
-                head_dim,
-                ssm_state_size,
-                device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
-            )
-
-        return {"ssm_state_cache": _get_ssm_cache}
-
-    @classmethod
-    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
-        return {}
-
-    @classmethod
-    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        time_step_limit, chunk_size = extract_op_args(
-            source_attn_node, "time_step_limit", "chunk_size"
-        )
-        return [time_step_limit, chunk_size]
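
With the duplicated descriptor methods removed, TritonBackendSSM now inherits its cache initializers, metadata op, layout, and constants from TorchBackendSSM and overrides only the cached attention op. A simplified sketch of the resulting pattern using stand-in classes and placeholder return values (not the real AttentionDescriptor interface):

class TorchBackendSSMSketch:
    @classmethod
    def get_cached_attention_op(cls) -> str:
        return "torch cached SSM op"  # placeholder for the torch backend op

    @classmethod
    def get_prepare_metadata_op(cls):
        # Shared metadata op: (seq_len, seq_start, slot_idx, use_initial_states).
        return "torch_ssm_prepare_metadata", 4


class TritonBackendSSMSketch(TorchBackendSSMSketch):
    # Only the cached attention op differs; everything else is inherited
    # from the torch backend base class.
    @classmethod
    def get_cached_attention_op(cls) -> str:
        return "triton cached SSM op"  # placeholder for the triton backend op


# The triton backend resolves the shared metadata op through inheritance:
assert TritonBackendSSMSketch.get_prepare_metadata_op()[1] == 4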
