From 38eb71c3b62fc36e9afbb542a1d7c220a2d7fbbf Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Thu, 4 Dec 2025 15:11:12 +0200
Subject: [PATCH 1/3] Reduce defrag operations in non-apc runs

Signed-off-by: Kamil Kaczor
---
 vllm_gaudi/extension/defragmentation.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm_gaudi/extension/defragmentation.py b/vllm_gaudi/extension/defragmentation.py
index a3efbfb97..47071950e 100644
--- a/vllm_gaudi/extension/defragmentation.py
+++ b/vllm_gaudi/extension/defragmentation.py
@@ -22,6 +22,8 @@ class CacheSwapUtils(torch.nn.Module):
     def __init__(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_size: int):
         super().__init__()
         self.block_size = block_size
+        config = get_config()
+        self.apc = config.prefix_caching
         self.kv_caches = tuple(kv_caches)
         self.block_slots = torch.arange(0, self.block_size, dtype=torch.long, device=kv_caches[0][0].device)
         self.is_mla = all([cache[1] is None for cache in self.kv_caches])
@@ -29,15 +31,17 @@ def __init__(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_si
     def forward(self, srcs: torch.tensor, dsts: torch.tensor, caches: list[torch.tensor]):
         """ Internal method wrapped in HPU/t.compile graphs"""
         htorch.core.mark_step()
-        srcs = ((srcs * self.block_size).unsqueeze(-1) + self.block_slots).flatten()
-        dsts = ((dsts * self.block_size).unsqueeze(-1) + self.block_slots).flatten()
+        srcs = ((srcs * self.block_size).unsqueeze(-1) + self.block_slots).flatten() # used
+        dsts = ((dsts * self.block_size).unsqueeze(-1) + self.block_slots).flatten() # free
         for cache in caches:
             prev_srcs = cache.index_select(0, srcs)
-            prev_dsts = cache.index_select(0, dsts)
+            # using apc we need to swap free blocks back as they can contain cached data
+            if self.apc:
+                prev_dsts = cache.index_select(0, dsts)
+                cache.index_copy_(0, srcs, prev_dsts)
+                prev_dsts = None
             cache.index_copy_(0, dsts, prev_srcs)
-            cache.index_copy_(0, srcs, prev_dsts)
             prev_srcs = None
-            prev_dsts = None
         srcs = None
         dsts = None
         htorch.core.mark_step()

From a6458f0c5dc5e3839ee24f3db7b0cca0d1a029d6 Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Mon, 8 Dec 2025 07:42:17 +0100
Subject: [PATCH 2/3] Fix precommit, rename condition

---
 vllm_gaudi/extension/defragmentation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm_gaudi/extension/defragmentation.py b/vllm_gaudi/extension/defragmentation.py
index 47071950e..3e46aef7a 100644
--- a/vllm_gaudi/extension/defragmentation.py
+++ b/vllm_gaudi/extension/defragmentation.py
@@ -23,7 +23,7 @@ def __init__(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_si
         super().__init__()
         self.block_size = block_size
         config = get_config()
-        self.apc = config.prefix_caching
+        self.enable_prefix_caching = config.prefix_caching
         self.kv_caches = tuple(kv_caches)
         self.block_slots = torch.arange(0, self.block_size, dtype=torch.long, device=kv_caches[0][0].device)
         self.is_mla = all([cache[1] is None for cache in self.kv_caches])
@@ -31,12 +31,12 @@ def __init__(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_si
     def forward(self, srcs: torch.tensor, dsts: torch.tensor, caches: list[torch.tensor]):
         """ Internal method wrapped in HPU/t.compile graphs"""
         htorch.core.mark_step()
-        srcs = ((srcs * self.block_size).unsqueeze(-1) + self.block_slots).flatten() # used
-        dsts = ((dsts * self.block_size).unsqueeze(-1) + self.block_slots).flatten() # free
+        srcs = ((srcs * self.block_size).unsqueeze(-1) + self.block_slots).flatten()  # used
+        dsts = ((dsts * self.block_size).unsqueeze(-1) + self.block_slots).flatten()  # free
         for cache in caches:
             prev_srcs = cache.index_select(0, srcs)
             # using apc we need to swap free blocks back as they can contain cached data
-            if self.apc:
+            if self.enable_prefix_caching:
                 prev_dsts = cache.index_select(0, dsts)
                 cache.index_copy_(0, srcs, prev_dsts)
                 prev_dsts = None

From d8bd751db00bfd86f356e8a44499283edf021a12 Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Tue, 9 Dec 2025 14:41:19 +0100
Subject: [PATCH 3/3] Skip creating get_config var

---
 vllm_gaudi/extension/defragmentation.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm_gaudi/extension/defragmentation.py b/vllm_gaudi/extension/defragmentation.py
index 3e46aef7a..a5dc4f643 100644
--- a/vllm_gaudi/extension/defragmentation.py
+++ b/vllm_gaudi/extension/defragmentation.py
@@ -22,8 +22,7 @@ class CacheSwapUtils(torch.nn.Module):
     def __init__(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_size: int):
         super().__init__()
         self.block_size = block_size
-        config = get_config()
-        self.enable_prefix_caching = config.prefix_caching
+        self.enable_prefix_caching = get_config().prefix_caching
         self.kv_caches = tuple(kv_caches)
         self.block_slots = torch.arange(0, self.block_size, dtype=torch.long, device=kv_caches[0][0].device)
         self.is_mla = all([cache[1] is None for cache in self.kv_caches])
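
Reviewer note, not part of the patch series: below is a minimal standalone sketch of the block-move logic this series converges on. It assumes a plain torch.Tensor cache whose first dimension indexes flattened block slots; the helper name move_blocks is hypothetical and does not exist in vllm_gaudi. With prefix caching enabled, the "free" destination blocks may still hold cached data, so their contents are swapped back into the source slots; without it, that read-back and extra copy are skipped, which is the reduction in defrag work that the first patch targets.

import torch

def move_blocks(cache: torch.Tensor, srcs: torch.Tensor, dsts: torch.Tensor,
                enable_prefix_caching: bool) -> None:
    # Hypothetical helper mirroring the per-cache body of CacheSwapUtils.forward.
    prev_srcs = cache.index_select(0, srcs)  # contents of the used (source) slots
    if enable_prefix_caching:
        # "Free" blocks may still hold prefix-cached data, so preserve it by swapping
        # it back into the source slots instead of simply overwriting the destinations.
        prev_dsts = cache.index_select(0, dsts)
        cache.index_copy_(0, srcs, prev_dsts)
    cache.index_copy_(0, dsts, prev_srcs)  # move used data into the free slots

# Toy usage: a cache of 4 slots of width 2; move slot 0 into slot 3 with APC disabled.
cache = torch.arange(8, dtype=torch.float32).reshape(4, 2)
move_blocks(cache, torch.tensor([0]), torch.tensor([3]), enable_prefix_caching=False)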