@@ -1,27 +1,14 @@
-from typing import List, Tuple
+from typing import List
 
 import torch
-from torch._ops import OpOverloadPacket
-from torch.fx import Node
 
 # Triton kernels
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
 from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
 from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined
 
-from ...utils.node_utils import extract_op_args
-from ..attention_interface import (
-    AttentionDescriptor,
-    AttentionLayout,
-    AttentionRegistry,
-    BufferInitializerDict,
-    CacheConfig,
-    CacheInitializerDict,
-    Constant,
-    MHACallable,
-    PrepareMetadataCallable,
-    SequenceInfo,
-)
+from ..attention_interface import AttentionRegistry, MHACallable
+from .torch_backend_mamba import TorchBackendSSM
 
 
 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
@@ -202,70 +189,7 @@ def _triton_cached_ssm_fake( |
 
 
 @AttentionRegistry.register("triton_ssm")
-class TritonBackendSSM(AttentionDescriptor):
-    @classmethod
-    def is_paged(cls) -> bool:
-        return True
-
-    @classmethod
-    def get_attention_layout(cls) -> AttentionLayout:
-        # Hidden states follow [b, s, n, d]
-        return "bsnd"
-
-    @classmethod
-    def get_num_qkv_args(cls) -> int:
-        # torch_ssm_transform signature has 7 node/state arguments
-        return 7
-
-    @classmethod
-    def get_source_attention_op(cls) -> OpOverloadPacket:
-        # Keep source op unchanged (used for uncached pre-export)
-        return torch.ops.auto_deploy.torch_ssm
-
+class TritonBackendSSM(TorchBackendSSM):
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
         return torch.ops.auto_deploy.triton_cached_ssm
-
-    @classmethod
-    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
-
-    @classmethod
-    def get_cache_initializers(
-        cls, source_attn_node: Node, cache_config: CacheConfig
-    ) -> CacheInitializerDict:
-        # Shapes from fake tensors
-        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
-        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
-
-        num_heads = hs_fake.shape[-2]
-        head_dim = hs_fake.shape[-1]
-
-        if B_fake.ndim >= 4:
-            ssm_state_size = B_fake.shape[-1]
-        else:
-            ssm_state_size = max(1, B_fake.shape[-1])
-
-        def _get_ssm_cache(si: SequenceInfo):
-            return torch.empty(
-                si.max_batch_size,
-                num_heads,
-                head_dim,
-                ssm_state_size,
-                device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
-            )
-
-        return {"ssm_state_cache": _get_ssm_cache}
-
-    @classmethod
-    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
-        return {}
-
-    @classmethod
-    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        time_step_limit, chunk_size = extract_op_args(
-            source_attn_node, "time_step_limit", "chunk_size"
-        )
-        return [time_step_limit, chunk_size]
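
For context: the descriptor methods deleted above are presumed to move into the shared TorchBackendSSM base class imported from .torch_backend_mamba, so the Triton backend only overrides the cached-attention op. Below is a minimal sketch of what that base class is assumed to provide, pieced together from the removed methods; the actual torch_backend_mamba module is not part of this diff, so treat the class body as an assumption rather than the real implementation.

from typing import List, Tuple

import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    Constant,
    PrepareMetadataCallable,
)


class TorchBackendSSM(AttentionDescriptor):
    """Assumed shared SSM descriptor; sketch based on the methods removed above."""

    @classmethod
    def is_paged(cls) -> bool:
        return True

    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, n, d]
        return "bsnd"

    @classmethod
    def get_num_qkv_args(cls) -> int:
        # torch_ssm_transform signature has 7 node/state arguments
        return 7

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        # Keep source op unchanged (used for uncached pre-export)
        return torch.ops.auto_deploy.torch_ssm

    @classmethod
    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        time_step_limit, chunk_size = extract_op_args(
            source_attn_node, "time_step_limit", "chunk_size"
        )
        return [time_step_limit, chunk_size]

    # get_cached_attention_op, get_cache_initializers, and
    # get_global_buffer_initializers are omitted here for brevity; the
    # ssm_state_cache initializer removed above is assumed to live here too.

With a base class along those lines, TritonBackendSSM only has to return torch.ops.auto_deploy.triton_cached_ssm from get_cached_attention_op, which is exactly what the remaining hunk keeps.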