
Commit df33486

[Hybrid] A simpler algorithm to find kernel_block_size (vllm-project#26476)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 0e0a638 commit df33486
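
In short, this commit collapses the previous `_find_compatible_block_sizes` / `_select_common_block_size` pair into a single static `GPUModelRunner.select_common_block_size`: keep the KV cache manager's block size if every attention backend in the group supports it, otherwise fall back to the largest exact (`int`) kernel block size that divides it and is accepted by all backends. The resulting `kernel_block_sizes` list is now computed once in `initialize_kv_cache` and passed explicitly to `may_reinitialize_input_batch` and `initialize_kv_cache_tensors`, rather than being re-derived at each call site.

The sketch below restates that selection logic as a standalone function so the two cases are easy to follow. It is illustrative only: the `MultipleOf` stand-in and the list-of-lists argument are simplifications, not vLLM's real class or call signature (the actual method operates on `AttentionGroup` objects, as the diff shows).

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vllm.attention.backends.abstract.MultipleOf: the backend
    # accepts any block size that is a multiple of `base`.
    base: int


def select_common_block_size(kv_manager_block_size, supported_per_backend):
    # supported_per_backend: one list per backend; each entry is an exact int
    # size or a MultipleOf constraint.

    def is_supported(backend_sizes, block_size):
        return any(
            block_size == s if isinstance(s, int) else block_size % s.base == 0
            for s in backend_sizes
        )

    # Case 1: the manager's own block size works for every backend.
    if all(is_supported(sizes, kv_manager_block_size) for sizes in supported_per_backend):
        return kv_manager_block_size

    # Case 2: try every exact int size any backend lists, largest first; it must
    # divide the manager block size and be accepted by all backends.
    int_sizes = {s for sizes in supported_per_backend for s in sizes if isinstance(s, int)}
    for size in sorted(int_sizes, reverse=True):
        if kv_manager_block_size % size == 0 and all(
            is_supported(sizes, size) for sizes in supported_per_backend
        ):
            return size
    raise ValueError(f"No common block size for {kv_manager_block_size}")


print(select_common_block_size(128, [[MultipleOf(32)], [64, MultipleOf(16)]]))  # 128
print(select_common_block_size(256, [[128, 64], [64, 32]]))  # 64

Making the selector a `@staticmethod` is also what lets the new unit tests below call it directly with mock backends, without constructing a full `GPUModelRunner`.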

File tree

3 files changed: +149 -85 lines changed


tests/v1/worker/test_gpu_model_runner.py

Lines changed: 53 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 
 from vllm.attention import Attention
+from vllm.attention.backends.abstract import MultipleOf
 from vllm.config import (
     CacheConfig,
     ModelConfig,
@@ -34,6 +35,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import AttentionGroup
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     ).all()
 
 
+def _make_mock_backend_for_kernel_block_size(
+    supported_sizes: list[int | MultipleOf],
+):
+    class _MockBackend:
+        @staticmethod
+        def get_supported_kernel_block_size():
+            return supported_sizes
+
+    return _MockBackend()
+
+
+def _make_kv_cache_spec() -> FullAttentionSpec:
+    return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
+
+
+def test_select_common_block_size_prefers_manager_block_size():
+    backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
+    assert selected_size == 128
+
+
+def test_select_common_block_size_uses_largest_shared_int():
+    backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
+    backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
+    assert selected_size == 64
+
+
+def test_select_common_block_size_no_valid_option():
+    backend_a = _make_mock_backend_for_kernel_block_size([64])
+    backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
+    attn_groups = [
+        AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
+        AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
+    ]
+
+    with pytest.raises(ValueError):
+        GPUModelRunner.select_common_block_size(48, attn_groups)
+
+
 def test_update_states_new_request(model_runner, dist_init):
     req_id = "req_0"
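Reading the three new tests against the selection logic: in `test_select_common_block_size_prefers_manager_block_size`, 128 satisfies both `MultipleOf(32)` and `MultipleOf(16)`, so the manager block size is returned unchanged (case 1). In `test_select_common_block_size_uses_largest_shared_int`, 256 is not listed by either mock backend, so the fallback scans the exact sizes {128, 64, 32} in descending order; 128 is rejected by the second backend, leaving 64 as the largest size that divides 256 and is accepted by both. In `test_select_common_block_size_no_valid_option`, the first backend only accepts exactly 64, which 48 is not, and the sole `int` candidate 64 does not divide 48, so a `ValueError` is raised.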

vllm/v1/worker/gpu_model_runner.py

Lines changed: 91 additions & 84 deletions
@@ -3978,6 +3978,7 @@ def get_attn_backends_for_group(
 
         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3987,6 +3988,7 @@ def create_attn_groups(
                     kv_cache_spec,
                     self.vllm_config,
                     self.device,
+                    kv_cache_group_id,
                     num_metadata_builders=1
                     if not self.parallel_config.enable_dbo
                     else 2,
@@ -4005,8 +4007,8 @@
         # Resolve cudagraph_mode before actually initialize metadata_builders
         self._check_and_update_cudagraph_mode(attention_backend_set)
 
-        for attn_backends_map in attention_backend_maps:
-            self.attn_groups.append(create_attn_groups(attn_backends_map))
+        for i, attn_backend_map in enumerate(attention_backend_maps):
+            self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
@@ -4156,104 +4158,96 @@ def calculate_reorder_batch_threshold(self) -> None:
             return
         self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
 
-    def _find_compatible_block_sizes(
-        self,
-        kv_manager_block_size: int,
-        backend_cls: type[AttentionBackend],
-        return_all: bool = False,
-    ) -> list[int]:
-        """
-        Find compatible block sizes for a backend.
-
-        Args:
-            kv_manager_block_size: Physical block size of KV cache
-            backend_cls: Attention backend class
-            return_all: Return all compatible sizes if True, max size if False
-
-        Returns:
-            Compatible block size(s) based on return_all parameter
-
-        Raises:
-            ValueError: If no compatible block size found
-        """
-        supported_block_size = backend_cls.get_supported_kernel_block_size()
-        compatible_sizes = []
-
-        for block_size in supported_block_size:
-            if isinstance(block_size, int):
-                if kv_manager_block_size % block_size == 0:
-                    compatible_sizes.append(block_size)
-            elif (
-                isinstance(block_size, MultipleOf)
-                and kv_manager_block_size % block_size.base == 0
-            ):
-                compatible_sizes.append(kv_manager_block_size)
-
-        if not compatible_sizes:
-            raise ValueError(f"No compatible block size for {kv_manager_block_size}")
-
-        return compatible_sizes if return_all else [max(compatible_sizes)]
-
-    def _select_common_block_size(
-        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    @staticmethod
+    def select_common_block_size(
+        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
     ) -> int:
         """
-        Select common block size for all backends.
+        Select a block size that is supported by all backends and is a factor of
+        kv_manager_block_size.
+
+        If kv_manager_block_size is supported by all backends, return it directly.
+        Otherwise, return the max supported size.
 
         Args:
             kv_manager_block_size: Block size of KV cache
            attn_groups: List of attention groups
 
         Returns:
-            Block size supported by all backends,
-            prioritizing cache_config.block_size
+            The selected block size
 
         Raises:
-            ValueError: If no common block size found
+            ValueError: If no valid block size found
         """
-        all_backend_supports = []
 
-        for attn_group in attn_groups:
-            compatible_sizes = self._find_compatible_block_sizes(
-                kv_manager_block_size, attn_group.backend, return_all=True
-            )
-            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
-            all_backend_supports.append(set(supported_sizes))
-
-        common_supported_sizes = set.intersection(*all_backend_supports)
-
-        if not common_supported_sizes:
-            error_msg = f"No common block size for {kv_manager_block_size}. "
-            for i, attn_group in enumerate(attn_groups):
-                supported = all_backend_supports[i]
-                error_msg += (
-                    f"Backend {attn_group.backend} supports: {sorted(supported)}. "
-                )
-            raise ValueError(error_msg)
-
-        if self.cache_config.block_size in common_supported_sizes:
-            return self.cache_config.block_size
+        def block_size_is_supported(
+            backends: list[type[AttentionBackend]], block_size: int
+        ) -> bool:
+            """
+            Check if the block size is supported by all backends.
+            """
+            for backend in backends:
+                is_supported = False
+                for supported_size in backend.get_supported_kernel_block_size():
+                    if isinstance(supported_size, int):
+                        if block_size == supported_size:
+                            is_supported = True
+                    elif isinstance(supported_size, MultipleOf):
+                        if block_size % supported_size.base == 0:
+                            is_supported = True
+                    else:
+                        raise ValueError(f"Unknown supported size: {supported_size}")
+                if not is_supported:
+                    return False
+            return True
+
+        backends = [group.backend for group in attn_groups]
+
+        # Case 1: if the block_size of kv cache manager is supported by all backends,
+        # return it directly
+        if block_size_is_supported(backends, kv_manager_block_size):
+            return kv_manager_block_size
+
+        # Case 2: otherwise, the block_size must be an `int`-format supported size of
+        # at least one backend. Iterate over all `int`-format supported sizes in
+        # descending order and return the first one that is supported by all backends.
+        # Simple proof:
+        # If the supported size b is in MultipleOf(x_i) format for all attention
+        # backends i, and b a factor of kv_manager_block_size, then
+        # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
+        # return kv_manager_block_size in case 1.
+        all_int_supported_sizes = set(
+            supported_size
+            for backend in backends
+            for supported_size in backend.get_supported_kernel_block_size()
+            if isinstance(supported_size, int)
+        )
 
-        return max(common_supported_sizes)
+        for supported_size in sorted(all_int_supported_sizes, reverse=True):
+            if kv_manager_block_size % supported_size != 0:
+                continue
+            if block_size_is_supported(backends, supported_size):
+                return supported_size
+        raise ValueError(f"No common block size for {kv_manager_block_size}. ")
 
-    def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
+    def may_reinitialize_input_batch(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
         """
         Re-initialize the input batch if the block sizes are different from
         `[self.cache_config.block_size]`. This usually happens when there
         are multiple KV cache groups.
 
         Args:
             kv_cache_config: The KV cache configuration.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         """
         block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
            if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
        ]
 
-        # Generate kernel_block_sizes that matches each block_size
-        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
-
         if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
             self.cache_config.block_size
         ]:
@@ -4354,7 +4348,7 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
             # all backends in the group.
             attn_groups = self.attn_groups[kv_cache_group_id]
             kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
-            selected_kernel_size = self._select_common_block_size(
+            selected_kernel_size = self.select_common_block_size(
                 kv_manager_block_size, attn_groups
             )
             kernel_block_sizes.append(selected_kernel_size)
@@ -4372,6 +4366,7 @@ def _reshape_kv_cache_tensors(
         self,
         kv_cache_config: KVCacheConfig,
         kv_cache_raw_tensors: dict[str, torch.Tensor],
+        kernel_block_sizes: list[int],
     ) -> dict[str, torch.Tensor]:
         """
         Reshape the KV cache tensors to the desired shape and dtype.
@@ -4380,6 +4375,7 @@ def _reshape_kv_cache_tensors(
             kv_cache_config: The KV cache config
             kv_cache_raw_tensors: The KV cache buffer of each layer, with
                 correct size but uninitialized shape.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4389,6 +4385,10 @@ def _reshape_kv_cache_tensors(
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
+            if group.kv_cache_group_id == len(kernel_block_sizes):
+                # There may be a last group for layers without kv cache.
+                continue
+            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
@@ -4397,24 +4397,21 @@
                 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_manager_block_size = kv_cache_spec.block_size
-                    kernel_size_list = self._find_compatible_block_sizes(
-                        kv_manager_block_size, attn_backend, return_all=False
+                    num_blocks_per_kv_block = (
+                        kv_cache_spec.block_size // kernel_block_size
                     )
-                    kernel_size = kernel_size_list[0]
-                    num_blocks_per_kv_block = kv_manager_block_size // kernel_size
                     kernel_num_blocks = num_blocks * num_blocks_per_kv_block
 
                     kv_cache_shape = attn_backend.get_kv_cache_shape(
                         kernel_num_blocks,
-                        kernel_size,
+                        kernel_block_size,
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                         cache_dtype_str=self.cache_config.cache_dtype,
                     )
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
+                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(kv_cache_shape)
                     except (AttributeError, NotImplementedError):
                         kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4497,13 +4494,15 @@ def _update_hybrid_attention_mamba_layout(
             )
 
     def initialize_kv_cache_tensors(
-        self, kv_cache_config: KVCacheConfig
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
     ) -> dict[str, torch.Tensor]:
         """
         Initialize the memory buffer for KV cache.
 
         Args:
             kv_cache_config: The KV cache config
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
+
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4512,7 +4511,7 @@ def initialize_kv_cache_tensors(
         kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
         # Change the memory buffer to the desired shape
         kv_caches = self._reshape_kv_cache_tensors(
-            kv_cache_config, kv_cache_raw_tensors
+            kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
        )
 
         # Set up cross-layer KV cache sharing
@@ -4571,9 +4570,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        # The kernel block size for all KV cache groups. For example, if
+        # kv_cache_manager uses block_size 256 for a given group, but the attention
+        # backends for that group only supports block_size 64, we will return
+        # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
+        # tokens each.
+        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
         # Reinitialize need to after initialize_attn_backend
-        self.may_reinitialize_input_batch(kv_cache_config)
-        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
+        kv_caches = self.initialize_kv_cache_tensors(
+            kv_cache_config, kernel_block_sizes
+        )
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
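
To make the comment added in `initialize_kv_cache` and the reshaping in `_reshape_kv_cache_tensors` concrete, here is a small worked sketch of the block-splitting arithmetic; the numbers are illustrative, and real cache shapes come from each backend's `get_kv_cache_shape`.

# Illustrative numbers: the KV cache manager hands out 256-token blocks, but the
# kernel block size selected for this group is 64 tokens.
kv_manager_block_size = 256
kernel_block_size = 64
num_blocks = 10  # manager-level blocks that fit in the raw tensor

num_blocks_per_kv_block = kv_manager_block_size // kernel_block_size  # 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block  # 40

# The same buffer is viewed as more, smaller blocks; total token capacity is unchanged.
assert num_blocks * kv_manager_block_size == kernel_num_blocks * kernel_block_size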
