Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/disaggregation/kv_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,24 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""


# Medium values for storage tiers (compatible with vLLM)
# These strings are used as the `medium` value on KV cache events; the radix
# cache tags the store/remove events it records with MEDIUM_GPU.
MEDIUM_GPU = "GPU"
# NOTE(review): per the PR discussion, CPU_TIER1 denotes the L2 tier, i.e. the
# pinned host_to_device / device_to_host buffers in HiRadixCache. The name is
# still under discussion -- confirm before relying on it.
MEDIUM_CPU_TIER1 = "CPU_TIER1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does CPU_TIER1 mean HiRadixCache? And does CPU_TIER2 mean the remote CPU memory pool?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit confusing actually.

  • CPU_TIER1 is actually L2 storage. It's the host_to_device and device_to_host buffers in HiRadixCache, which are pinned memory.
  • CPU_TIER2 is actually L3 storage. This is the remote storage backend

I'm not sure why vLLM does it this way. I am OK with changing this, by the way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about MEDIUM_CPU = "CPU_PINNED", as discussed in Slack?
And let's have the backend storage emit the KV events itself; then there is no need for MEDIUM_CPU_TIER2 here?

# NOTE(review): per the PR discussion, CPU_TIER2 denotes the L3 tier, i.e. the
# remote storage backend. Reviewers suggested it may become unnecessary if the
# storage backend emits its own KV events -- confirm before relying on it.
MEDIUM_CPU_TIER2 = "CPU_TIER2"


# KV cache event recording that blocks were stored.
# The radix cache's _record_store_event builds these with token_ids set to a
# page's tokens, block_size=len(page_tokens), lora_id=None and
# medium=MEDIUM_GPU.
class BlockStored(KVCacheEvent):
# Integer hashes identifying the stored blocks.
block_hashes: list[int]
# Presumably the hash of the block preceding these ones, None when there is
# no parent -- TODO confirm against the emitter.
parent_block_hash: Optional[int]
# Token ids covered by the stored blocks.
token_ids: list[int]
# Number of tokens per block (the radix cache passes len(page_tokens)).
block_size: int
# LoRA adapter id; the radix cache call site always passes None.
lora_id: Optional[int]
# Storage-tier tag (one of the MEDIUM_* strings). Defaults to None, so
# emitters that do not set it keep working.
medium: Optional[str] = None


# KV cache event recording that blocks were removed.
# The radix cache's _record_remove_event appends one of these per page, with
# a single-element block_hashes list and medium=MEDIUM_GPU.
class BlockRemoved(KVCacheEvent):
# Integer hashes identifying the removed blocks.
block_hashes: list[int]
# Storage-tier tag (one of the MEDIUM_* strings). Defaults to None, so
# emitters that do not set it keep working.
medium: Optional[str] = None


class AllBlocksCleared(KVCacheEvent):
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
logger = logging.getLogger(__name__)

from sglang.srt.disaggregation.kv_events import (
MEDIUM_GPU,
AllBlocksCleared,
BlockRemoved,
BlockStored,
Expand Down Expand Up @@ -822,6 +823,7 @@ def _record_store_event(self, node: TreeNode):
token_ids=page_tokens,
block_size=len(page_tokens),
lora_id=None,
medium=MEDIUM_GPU,
)
)

Expand All @@ -843,7 +845,9 @@ def _record_remove_event(self, node: TreeNode):

block_hash = hash_str_to_int64(node.hash_value[page_index])

self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash], medium=MEDIUM_GPU)
)

page_index += 1

Expand Down
Loading