1 change: 1 addition & 0 deletions examples/auto_deploy/.gitignore
@@ -4,3 +4,4 @@ benchmark_results.json
*.png
# ignore config files that users might put here for debugging
*.yaml
!nano_v3.yaml
23 changes: 23 additions & 0 deletions examples/auto_deploy/nano_v3.yaml
@@ -0,0 +1,23 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 65536 # tunable
enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since it is not supported for hybrid/ssm models
  enable_block_reuse: false
transforms:
  detect_sharding:
    sharding_source: ['factory', 'heuristic']
    sharding_dims: ['ep', 'bmm']
  # tunable mamba cache dtype
  # --> use float32 for accuracy and the default (null) for speed
  insert_cached_ssm_attention:
    cache_config:
      # mamba_dtype: float32
      mamba_dtype: null
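For reference, a minimal sketch of how the `cache_config` block above is interpreted once the new `CacheConfig` model (added further down in this PR) parses it. The YAML snippet and the import path are taken from this diff; the loading code itself is illustrative, not the production config loader.

```python
import torch
import yaml

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

snippet = """
transforms:
  insert_cached_ssm_attention:
    cache_config:
      mamba_dtype: float32
"""

raw = yaml.safe_load(snippet)
cache_config = CacheConfig(**raw["transforms"]["insert_cached_ssm_attention"]["cache_config"])
assert cache_config.mamba_dtype is torch.float32  # the string is coerced to a torch.dtype
```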
@@ -10,10 +10,10 @@
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number
@@ -24,11 +24,39 @@
Constant = Union[int, float, str, None]


@dataclass
class CacheConfig:
"""A dataclass to hold information how to configure the cache."""
class CacheConfig(BaseModel):
"""Cache configuration for attention-related dtypes."""

dtype: Optional[torch.dtype] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)

dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")

@field_validator("dtype", "mamba_dtype", mode="before")
@classmethod
def _coerce_dtype(cls, value):
if value is None or isinstance(value, torch.dtype):
return value
if isinstance(value, str):
dtype = getattr(torch, value, None)
assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
return dtype
return value

def __or__(self, other: "CacheConfig") -> "CacheConfig":
"""Combine two CacheConfig objects field-wise using Python's `or` semantics.

For each field, selects the first non-None value between `self` and `other`.
"""
if not isinstance(other, CacheConfig):
raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
merged_kwargs = {}
for field_name in type(self).model_fields.keys():
merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
return CacheConfig(**merged_kwargs)


class SequenceInfo:
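A quick sketch of how the new `CacheConfig` model behaves, assuming it is importable from the module shown in this diff; the misspelled key is a hypothetical example used only to show the `extra="forbid"` behavior.

```python
import torch
from pydantic import ValidationError

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

# strings are coerced to torch dtypes by the field validator
assert CacheConfig(dtype="bfloat16").dtype is torch.bfloat16
assert CacheConfig(mamba_dtype="float32").mamba_dtype is torch.float32

# unknown fields are rejected because the model sets extra="forbid"
try:
    CacheConfig(kv_dtype="float16")  # hypothetical misspelled field name
except ValidationError:
    pass
```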
@@ -347,14 +347,17 @@ def get_cache_initializers(
# Fallback: assume last dim is n_groups * state_size and choose a minimal positive size
ssm_state_size = max(1, B_fake.shape[-1])

# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype

def _get_ssm_cache(si: SequenceInfo):
return torch.empty(
si.max_batch_size,
num_heads,
head_dim,
ssm_state_size,
device=si.device,
dtype=cache_config.dtype or hs_fake.dtype,
dtype=ssm_state_dtype,
)

return {"ssm_state_cache": _get_ssm_cache}
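To make the dtype selection above concrete, here is a small sketch of the fallback logic, assuming the `CacheConfig` model from this diff; the hidden-states dtype is illustrative.

```python
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

hs_fake_dtype = torch.bfloat16  # dtype of the fake hidden-states input (illustrative)

# with mamba_dtype unset, the SSM state cache inherits the hidden-states dtype
assert (CacheConfig().mamba_dtype or hs_fake_dtype) is torch.bfloat16

# an explicit mamba_dtype (e.g. float32 from nano_v3.yaml) overrides the fallback;
# note that the KV-cache `dtype` field no longer affects the SSM state cache
cfg = CacheConfig(dtype=torch.float16, mamba_dtype="float32")
assert (cfg.mamba_dtype or hs_fake_dtype) is torch.float32
```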
@@ -125,6 +125,7 @@ def _triton_cached_ssm(
dt_limit=(time_step_limit[0], time_step_limit[1]),
return_final_states=False,
return_varlen_states=True,
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
)

y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
@@ -198,9 +199,7 @@ def _triton_cached_ssm_fake(
)


## Note: we reuse the existing metadata custom op and its registered fake from torch backend.


# TODO: consider inheriting from TorchBackendSSM instead of redefining everything
@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
@classmethod
@@ -247,14 +246,17 @@ def get_cache_initializers(
else:
ssm_state_size = max(1, B_fake.shape[-1])

# extract ssm_state_dtype from cache_config or hs_fake
ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype

def _get_ssm_cache(si: SequenceInfo):
return torch.empty(
si.max_batch_size,
num_heads,
head_dim,
ssm_state_size,
device=si.device,
dtype=cache_config.dtype or hs_fake.dtype,
dtype=ssm_state_dtype,
)

return {"ssm_state_cache": _get_ssm_cache}
14 changes: 12 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
@@ -8,7 +8,12 @@
from pydantic import Field
from torch.fx import GraphModule, Node

from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegistry, Constant
from ...custom_ops.attention_interface import (
AttentionDescriptor,
AttentionRegistry,
CacheConfig,
Constant,
)
from ...distributed.common import all_gather_object, get_world_size
from ...distributed.common import is_initialized as is_distributed_initialized
from ...models.factory import ModelFactory
@@ -66,6 +71,9 @@ class InsertCachedAttentionConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""

backend: Optional[str] = Field(default=None, description="The attention backend to use.")
cache_config: CacheConfig = Field(
default_factory=CacheConfig, description="The custom cache configuration to use."
)


@TransformRegistry.register("insert_cached_attention")
@@ -137,7 +145,9 @@ def _apply(
"""Replace uncached source attention node with corresponding cached attn node."""
attn_descriptor = self.attn_descriptor

cache_config = factory.get_cache_config()
# combine the transform's cache config with the factory's via the field-wise `or`;
# the transform config takes precedence over the factory config
cache_config = self.config.cache_config | factory.get_cache_config()

# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
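To illustrate the precedence rule in `_apply`, here is a minimal sketch using the `CacheConfig` model from this diff; the concrete dtype values are illustrative, with the factory values standing in for whatever `factory.get_cache_config()` would return.

```python
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

# what a user put under transforms.<name>.cache_config (e.g. in nano_v3.yaml)
transform_cache_config = CacheConfig(mamba_dtype="float32")
# what the model factory reports (illustrative values)
factory_cache_config = CacheConfig(dtype=torch.float16, mamba_dtype=torch.bfloat16)

cache_config = transform_cache_config | factory_cache_config
assert cache_config.mamba_dtype is torch.float32  # the transform config wins
assert cache_config.dtype is torch.float16        # unset fields fall back to the factory
```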