Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions vllm_gaudi/extension/defragmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@ class CacheSwapUtils(torch.nn.Module):
def __init__(self, kv_caches: tuple[tuple[torch.Tensor, torch.Tensor]], block_size: int):
    """Store the KV caches and precompute per-block slot offsets.

    Args:
        kv_caches: Iterable of cache pairs, one pair per entry. The second
            element of a pair may be ``None``; when it is ``None`` for
            every pair the instance is flagged as MLA (``self.is_mla``).
        block_size: Number of slots that make up one cache block.
    """
    super().__init__()
    self.block_size = block_size
    config = get_config()
    # Prefix-caching flag; forward() uses it to decide whether the contents
    # of the destination blocks have to be preserved (swapped back).
    self.apc = config.prefix_caching
    self.kv_caches = tuple(kv_caches)
    # Offsets 0..block_size-1, created on the same device as the first cache
    # tensor so they can be added to block indices without a device copy.
    self.block_slots = torch.arange(0, self.block_size, dtype=torch.long, device=kv_caches[0][0].device)
    # True only when no pair carries a second tensor.
    self.is_mla = all(cache[1] is None for cache in self.kv_caches)

def forward(self, srcs: torch.Tensor, dsts: torch.Tensor, caches: list[torch.Tensor]):
    """Move the contents of ``srcs`` blocks into ``dsts`` blocks, in place.

    Internal method wrapped in HPU/t.compile graphs.

    Args:
        srcs: Indices of the blocks currently in use that should be moved.
        dsts: Indices of the free blocks that receive the data.
        caches: Cache tensors to update. Each one is indexed along dim 0 by
            ``block_index * block_size + slot`` -- NOTE(review): this assumes
            dim 0 is the flattened slot dimension; confirm against callers.

    When prefix caching (``self.apc``) is enabled the previous contents of
    the destination blocks are written back into the source blocks, i.e. the
    two sets of blocks are swapped. Otherwise the copy is one-way and the
    source blocks keep their old data.
    """
    htorch.core.mark_step()
    # Expand every block index into the indices of all slots of that block.
    # This must happen exactly once: the results are slot indices, not block
    # indices, so expanding them again would address the wrong rows.
    srcs = ((srcs * self.block_size).unsqueeze(-1) + self.block_slots).flatten()  # used blocks
    dsts = ((dsts * self.block_size).unsqueeze(-1) + self.block_slots).flatten()  # free blocks
    for cache in caches:
        # Snapshot the source rows before anything is overwritten.
        prev_srcs = cache.index_select(0, srcs)
        # With APC the free blocks can still contain cached data, so they are
        # swapped back into the source positions instead of being discarded.
        if self.apc:
            prev_dsts = cache.index_select(0, dsts)
            cache.index_copy_(0, srcs, prev_dsts)
            prev_dsts = None  # drop the temporary as soon as it is consumed
        cache.index_copy_(0, dsts, prev_srcs)
        prev_srcs = None  # drop the temporary before the next cache
    # Release the expanded index tensors before closing the step.
    srcs = None
    dsts = None
    htorch.core.mark_step()
Expand Down
Loading