16 changes: 15 additions & 1 deletion python/sglang/srt/mem_cache/hiradix_cache.py
@@ -21,10 +21,15 @@
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
NSATokenToKVPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
NSATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import (
RadixCache,
@@ -70,6 +75,15 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
server_args.hicache_mem_layout,
allocator_type=server_args.hicache_storage_backend,
)
elif isinstance(self.kv_cache, NSATokenToKVPool):
self.token_to_kv_pool_host = NSATokenToKVPoolHost(
self.kv_cache,
server_args.hicache_ratio,
server_args.hicache_size,
self.page_size,
server_args.hicache_mem_layout,
allocator_type=server_args.hicache_storage_backend,
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache,
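One detail worth calling out in the hunk above: the new NSA branch is placed before the existing MLA branch. Assuming NSATokenToKVPool subclasses MLATokenToKVPool (consistent with the host side, where NSATokenToKVPoolHost subclasses MLATokenToKVPoolHost), that order is load-bearing: an NSA pool would also pass the MLA isinstance check. A minimal sketch of the pattern, with stand-in class names:

```python
# Minimal sketch; the device-pool inheritance relationship is an assumption.
class MLAPool: ...
class NSAPool(MLAPool): ...

def pick_host_pool(kv_cache):
    # NSA must be tested first: an NSAPool instance is also an MLAPool,
    # so swapping the branches would silently select the MLA host pool.
    if isinstance(kv_cache, NSAPool):
        return "NSATokenToKVPoolHost"
    elif isinstance(kv_cache, MLAPool):
        return "MLATokenToKVPoolHost"
    raise ValueError(f"Unsupported pool: {type(kv_cache)}")

assert pick_host_pool(NSAPool()) == "NSATokenToKVPoolHost"
```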
220 changes: 189 additions & 31 deletions python/sglang/srt/mem_cache/memory_pool_host.py
@@ -15,7 +15,12 @@
from sglang.jit_kernel.hicache import (
transfer_hicache_one_layer as jit_transfer_hicache_one_layer,
)
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.mem_cache.memory_pool import (
KVCache,
MHATokenToKVPool,
MLATokenToKVPool,
NSATokenToKVPool,
)
from sglang.srt.utils import is_cuda, is_npu, is_xpu

_is_cuda = is_cuda()
@@ -689,7 +694,9 @@ def __init__(
pin_memory: bool = True,
device: str = "cpu",
allocator_type: str = "default",
override_kv_cache_dim: Optional[int] = None,
):
self.override_kv_cache_dim = override_kv_cache_dim
super().__init__(
device_pool,
host_to_device_ratio,
@@ -711,13 +718,10 @@ def get_size_per_token(self):
self.kv_lora_rank = self.device_pool.kv_lora_rank
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num

return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
self.kv_cache_dim = self.override_kv_cache_dim or (
self.kv_lora_rank + self.qk_rope_head_dim
)
return self.kv_cache_dim * self.dtype.itemsize * self.layer_num

def get_ksize_per_token(self):
return self.get_size_per_token()
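As a back-of-envelope check of the refactored get_size_per_token, here is the formula evaluated with illustrative DeepSeek-style dimensions (none of these numbers come from the PR):

```python
# Illustrative values only: kv_lora_rank=512, qk_rope_head_dim=64,
# bf16 KV cache (2-byte itemsize), 61 layers.
kv_lora_rank, qk_rope_head_dim = 512, 64
itemsize, layer_num = 2, 61
override_kv_cache_dim = None  # NSA passes the device pool's kv_cache_dim here

kv_cache_dim = override_kv_cache_dim or (kv_lora_rank + qk_rope_head_dim)
size_per_token = kv_cache_dim * itemsize * layer_num
print(size_per_token)  # 576 * 2 * 61 = 70272 bytes per token
```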
@@ -728,22 +732,22 @@ def init_kv_buffer(self):
self.layer_num,
self.size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
elif self.layout == "page_first":
dims = (
self.size,
self.layer_num,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
elif self.layout == "page_first_direct":
dims = (
self.page_num,
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
# Ascend-specific: Aligns with NPUMLATokenToKVPool layout
# Separately allocate k_buffer and v_buffer for easier data transfer.
@@ -774,9 +778,7 @@ def init_kv_buffer(self):
return self.k_buffer
else:
raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = (
self.kv_lora_rank + self.qk_rope_head_dim
) * self.dtype.itemsize
self.token_stride_size = self.kv_cache_dim * self.dtype.itemsize
self.layout_dim = self.token_stride_size * self.layer_num

alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
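The hunk above replaces each repeated `kv_lora_rank + qk_rope_head_dim` expression with the cached `kv_cache_dim`, which is what lets NSA override the width. For orientation, the host-buffer shapes per layout, restated with assumed sizes:

```python
# Assumed sizes for illustration: 1024 tokens, 64-token pages, 61 layers,
# kv_cache_dim=576. The shape tuples mirror the dims in the hunk above.
size, page_size, layer_num, kv_cache_dim = 1024, 64, 61, 576
page_num = size // page_size

layouts = {
    "layer_first": (layer_num, size, 1, kv_cache_dim),
    "page_first": (size, layer_num, 1, kv_cache_dim),
    "page_first_direct": (page_num, layer_num, page_size, 1, kv_cache_dim),
}
for name, dims in layouts.items():
    print(f"{name}: {dims}")
```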
@@ -933,7 +935,7 @@ def get_dummy_flat_data_page(self) -> torch.Tensor:
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
),
dtype=self.dtype,
device=self.device,
@@ -946,14 +948,14 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
elif self.layout == "page_first":
self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
self.page_size,
self.layer_num,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
elif self.layout == "page_first_direct":
real_index = index // self.page_size
@@ -962,7 +964,7 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_cache_dim,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
@@ -980,38 +982,194 @@ def get_page_buffer_meta(self, indices):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ layer_id
* self.size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ indices[index] * self.kv_cache_dim * self.dtype.itemsize
+ layer_id * self.size * self.kv_cache_dim * self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size = self.dtype.itemsize * self.page_size * self.kv_cache_dim
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.kv_cache_dim
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.kv_cache_dim
)
element_size_list = [element_size] * len(ptr_list)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
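To make the pointer arithmetic in get_page_buffer_meta concrete, here is the layer_first formula evaluated with made-up numbers; only the formula itself mirrors the code:

```python
# All values are hypothetical; only the address formula comes from the code.
kv_buffer_data_ptr = 0          # stand-in base address
kv_cache_dim, itemsize = 576, 2
size, page_size = 1024, 64      # host pool capacity in tokens
token_index, layer_id = 128, 3  # first token of some page, in some layer

k_ptr = (
    kv_buffer_data_ptr
    + token_index * kv_cache_dim * itemsize
    + layer_id * size * kv_cache_dim * itemsize
)
element_size = itemsize * page_size * kv_cache_dim
print(k_ptr, element_size)  # 3686400 73728 -> one (ptr, nbytes) pair per page per layer
```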


class NSATokenToKVPoolHost(MLATokenToKVPoolHost):
device_pool: NSATokenToKVPool

def __init__(
self,
device_pool: NSATokenToKVPool,
host_to_device_ratio: float,
host_size: int,
page_size: int,
layout: str,
pin_memory: bool = True,
device: str = "cpu",
allocator_type: str = "default",
):
# Initialize indexer metadata before HostKVCache.__init__ calls get_size_per_token.
self.index_head_dim = device_pool.index_head_dim
self.indexer_quant_block_size = device_pool.quant_block_size
self.indexer_dtype = NSATokenToKVPool.index_k_with_scale_buffer_dtype
self.indexer_size_per_token = (
self.index_head_dim
+ self.index_head_dim // self.indexer_quant_block_size * 4
)
super().__init__(
device_pool,
host_to_device_ratio,
host_size,
page_size,
layout,
pin_memory,
device,
allocator_type,
override_kv_cache_dim=device_pool.kv_cache_dim,
)
self.indexer_page_stride_size = (
self.indexer_size_per_token * self.page_size * self.indexer_dtype.itemsize
)
# Ceiling division: enough pages to cover every host token slot.
self.indexer_page_num = (self.size + self.page_size - 1) // self.page_size
self._init_indexer_buffers()
logger.info(
f"NSATokenToKVPoolHost initialized with indexer page stride size: {self.indexer_page_stride_size}, page num: {self.indexer_page_num}"
)

def get_size_per_token(self):
base = super().get_size_per_token()
return (
base
+ self.indexer_size_per_token * self.layer_num * self.indexer_dtype.itemsize
)
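A quick sanity check of the indexer sizing. The dimensions below are assumptions in the spirit of DeepSeek's NSA indexer (index_head_dim=128, 128-wide quant blocks, fp8 keys and fp32 scales packed in a byte-typed buffer); the diff itself does not state them:

```python
# Assumed values for illustration only.
index_head_dim, quant_block_size = 128, 128
indexer_itemsize = 1            # assuming a uint8-typed packed buffer
page_size, layer_num = 64, 61
base_size_per_token = 70272     # from the MLA example earlier

# One byte per key element plus one 4-byte scale per quant block.
indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4
indexer_page_stride_size = indexer_size_per_token * page_size * indexer_itemsize
total_per_token = base_size_per_token + indexer_size_per_token * layer_num * indexer_itemsize
print(indexer_size_per_token)    # 132
print(indexer_page_stride_size)  # 8448
print(total_per_token)           # 70272 + 8052 = 78324
```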

def _init_indexer_buffers(self):
alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
self.index_k_with_scale_buffer = [
alloc_func(
(self.indexer_page_num, self.indexer_page_stride_size),
dtype=self.indexer_dtype,
device=self.device,
pin_memory=self.pin_memory,
allocator=self.allocator,
)
for _ in range(self.layer_num)
]
self.index_k_data_refs = [
self.index_k_with_scale_buffer[i] for i in range(self.layer_num)
]
self.index_k_data_ptrs = torch.tensor(
[x.data_ptr() for x in self.index_k_data_refs],
dtype=torch.uint64,
device=self.device_pool.device,
)
self.index_k_device_ptrs = torch.tensor(
[x.data_ptr() for x in self.device_pool.index_k_with_scale_buffer],
dtype=torch.uint64,
device=self.device_pool.device,
)

def _get_indexer_page_indices(self, host_indices, device_indices):
if host_indices.numel() == 0:
return host_indices, device_indices
if host_indices.numel() % self.page_size != 0:
raise ValueError(
"Index buffer transfer expects page-aligned indices for NSA."
)
host_page_indices = (
host_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
)
device_page_indices = (
device_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
)
return host_page_indices, device_page_indices
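A toy run of _get_indexer_page_indices, which converts the token-granular indices used by the KV transfer into the page-granular indices the indexer buffer is strided by:

```python
import torch

# Toy illustration with page_size=4; real page sizes come from server args.
page_size = 4
host_indices = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])

# First token of each page, divided by page_size, yields the page index.
host_page_indices = host_indices.reshape(-1, page_size)[:, 0] // page_size
print(host_page_indices)  # tensor([2, 5])
```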

def _load_indexer_to_device_per_layer(
self, device_pool, host_indices, device_indices, layer_id, io_backend
):
host_page_indices, device_page_indices = self._get_indexer_page_indices(
host_indices, device_indices
)
use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
if use_kernel:
transfer_kv_per_layer_mla(
src=self.index_k_with_scale_buffer[layer_id],
dst=device_pool.index_k_with_scale_buffer[layer_id],
src_indices=host_page_indices,
dst_indices=device_page_indices,
item_size=self.indexer_page_stride_size,
)
else:
transfer_kv_direct(
src_layers=[self.index_k_with_scale_buffer[layer_id]],
dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]],
src_indices=host_page_indices,
dst_indices=device_page_indices,
page_size=1,
)

def _backup_indexer_from_device_all_layer(
self, device_pool, host_indices, device_indices, io_backend
):
host_page_indices, device_page_indices = self._get_indexer_page_indices(
host_indices, device_indices
)
use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
if use_kernel:
transfer_kv_all_layer_mla(
src_layers=self.index_k_device_ptrs,
dst_layers=self.index_k_data_ptrs,
src_indices=device_page_indices,
dst_indices=host_page_indices,
item_size=self.indexer_page_stride_size,
num_layers=self.layer_num,
)
else:
transfer_kv_direct(
src_layers=device_pool.index_k_with_scale_buffer,
dst_layers=self.index_k_with_scale_buffer,
src_indices=device_page_indices,
dst_indices=host_page_indices,
page_size=1,
)

def load_to_device_per_layer(
self,
device_pool,
host_indices,
device_indices,
layer_id,
io_backend,
):
super().load_to_device_per_layer(
device_pool, host_indices, device_indices, layer_id, io_backend
)
self._load_indexer_to_device_per_layer(
device_pool, host_indices, device_indices, layer_id, io_backend
)

def backup_from_device_all_layer(
self, device_pool, host_indices, device_indices, io_backend
):
super().backup_from_device_all_layer(
device_pool, host_indices, device_indices, io_backend
)
self._backup_indexer_from_device_all_layer(
device_pool, host_indices, device_indices, io_backend
)
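Finally, both indexer transfer helpers gate the fused-kernel path on the same condition. A toy restatement of that eligibility check (the 8-byte alignment requirement is inferred from the `% 8` test in the code, not from kernel documentation):

```python
# Toy restatement of the kernel-eligibility test used by both helpers above.
def pick_backend(io_backend: str, page_stride_bytes: int) -> str:
    if io_backend == "kernel" and page_stride_bytes % 8 == 0:
        return "kernel"  # transfer_kv_per_layer_mla / transfer_kv_all_layer_mla
    return "direct"      # transfer_kv_direct fallback

assert pick_backend("kernel", 8448) == "kernel"  # aligned page stride
assert pick_backend("kernel", 132) == "direct"   # unaligned -> direct copy
assert pick_backend("direct", 8448) == "direct"
```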