@@ -3978,6 +3978,7 @@ def get_attn_backends_for_group(
 
         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3987,6 +3988,7 @@ def create_attn_groups(
                 kv_cache_spec,
                 self.vllm_config,
                 self.device,
+                kv_cache_group_id,
                 num_metadata_builders=1
                 if not self.parallel_config.enable_dbo
                 else 2,
@@ -4005,8 +4007,8 @@ def create_attn_groups(
         # Resolve cudagraph_mode before actually initialize metadata_builders
         self._check_and_update_cudagraph_mode(attention_backend_set)
 
-        for attn_backends_map in attention_backend_maps:
-            self.attn_groups.append(create_attn_groups(attn_backends_map))
+        for i, attn_backend_map in enumerate(attention_backend_maps):
+            self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
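
For illustration, here is a minimal standalone sketch of the group-id plumbing introduced above: each per-KV-cache-group backend map is tagged with its group index via enumerate(). The AttentionGroup class and the layer/backend names below are simplified stand-ins, not vLLM's real definitions.

from dataclasses import dataclass


@dataclass
class AttentionGroup:  # simplified stand-in for vLLM's AttentionGroup
    backend: str
    layer_names: list[str]
    kv_cache_group_id: int


def create_attn_groups(
    backend_map: dict[str, list[str]], kv_cache_group_id: int
) -> list[AttentionGroup]:
    # Build one group per backend, remembering which KV cache group it came from.
    return [
        AttentionGroup(backend, layers, kv_cache_group_id)
        for backend, layers in backend_map.items()
    ]


# One backend map per KV cache group (hypothetical layer names).
attention_backend_maps = [
    {"FLASH_ATTN": ["layers.0.attn", "layers.2.attn"]},  # group 0
    {"MAMBA": ["layers.1.mixer"]},                        # group 1
]
attn_groups = [
    create_attn_groups(backend_map, i)
    for i, backend_map in enumerate(attention_backend_maps)
]
assert attn_groups[1][0].kv_cache_group_id == 1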
@@ -4156,104 +4158,96 @@ def calculate_reorder_batch_threshold(self) -> None:
             return
         self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
 
-    def _find_compatible_block_sizes(
-        self,
-        kv_manager_block_size: int,
-        backend_cls: type[AttentionBackend],
-        return_all: bool = False,
-    ) -> list[int]:
-        """
-        Find compatible block sizes for a backend.
-
-        Args:
-            kv_manager_block_size: Physical block size of KV cache
-            backend_cls: Attention backend class
-            return_all: Return all compatible sizes if True, max size if False
-
-        Returns:
-            Compatible block size(s) based on return_all parameter
-
-        Raises:
-            ValueError: If no compatible block size found
-        """
-        supported_block_size = backend_cls.get_supported_kernel_block_size()
-        compatible_sizes = []
-
-        for block_size in supported_block_size:
-            if isinstance(block_size, int):
-                if kv_manager_block_size % block_size == 0:
-                    compatible_sizes.append(block_size)
-            elif (
-                isinstance(block_size, MultipleOf)
-                and kv_manager_block_size % block_size.base == 0
-            ):
-                compatible_sizes.append(kv_manager_block_size)
-
-        if not compatible_sizes:
-            raise ValueError(f"No compatible block size for {kv_manager_block_size}")
-
-        return compatible_sizes if return_all else [max(compatible_sizes)]
-
-    def _select_common_block_size(
-        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    @staticmethod
+    def select_common_block_size(
+        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
     ) -> int:
42004165 """
4201- Select common block size for all backends.
4166+ Select a block size that is supported by all backends and is a factor of
4167+ kv_manager_block_size.
4168+
4169+ If kv_manager_block_size is supported by all backends, return it directly.
4170+ Otherwise, return the max supported size.
42024171
42034172 Args:
42044173 kv_manager_block_size: Block size of KV cache
42054174 attn_groups: List of attention groups
42064175
42074176 Returns:
4208- Block size supported by all backends,
4209- prioritizing cache_config.block_size
4177+ The selected block size
42104178
42114179 Raises:
4212- ValueError: If no common block size found
4180+ ValueError: If no valid block size found
42134181 """
-        all_backend_supports = []
 
-        for attn_group in attn_groups:
-            compatible_sizes = self._find_compatible_block_sizes(
-                kv_manager_block_size, attn_group.backend, return_all=True
-            )
-            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
-            all_backend_supports.append(set(supported_sizes))
-
-        common_supported_sizes = set.intersection(*all_backend_supports)
-
-        if not common_supported_sizes:
-            error_msg = f"No common block size for {kv_manager_block_size}. "
-            for i, attn_group in enumerate(attn_groups):
-                supported = all_backend_supports[i]
-                error_msg += (
-                    f"Backend {attn_group.backend} supports: {sorted(supported)}. "
-                )
-            raise ValueError(error_msg)
-
-        if self.cache_config.block_size in common_supported_sizes:
-            return self.cache_config.block_size
+        def block_size_is_supported(
+            backends: list[type[AttentionBackend]], block_size: int
+        ) -> bool:
+            """
+            Check if the block size is supported by all backends.
+            """
+            for backend in backends:
+                is_supported = False
+                for supported_size in backend.get_supported_kernel_block_size():
+                    if isinstance(supported_size, int):
+                        if block_size == supported_size:
+                            is_supported = True
+                    elif isinstance(supported_size, MultipleOf):
+                        if block_size % supported_size.base == 0:
+                            is_supported = True
+                    else:
+                        raise ValueError(f"Unknown supported size: {supported_size}")
+                if not is_supported:
+                    return False
+            return True
+
+        backends = [group.backend for group in attn_groups]
+
+        # Case 1: if the block_size of the kv cache manager is supported by all
+        # backends, return it directly.
+        if block_size_is_supported(backends, kv_manager_block_size):
+            return kv_manager_block_size
+
+        # Case 2: otherwise, the block_size must be an `int`-format supported size of
+        # at least one backend. Iterate over all `int`-format supported sizes in
+        # descending order and return the first one that is supported by all backends.
+        # Simple proof:
+        # If the supported size b is in MultipleOf(x_i) format for all attention
+        # backends i, and b is a factor of kv_manager_block_size, then
+        # kv_manager_block_size also satisfies MultipleOf(x_i) for all i, and we
+        # would have returned kv_manager_block_size in case 1.
+        all_int_supported_sizes = set(
+            supported_size
+            for backend in backends
+            for supported_size in backend.get_supported_kernel_block_size()
+            if isinstance(supported_size, int)
+        )
 
-        return max(common_supported_sizes)
+        for supported_size in sorted(all_int_supported_sizes, reverse=True):
+            if kv_manager_block_size % supported_size != 0:
+                continue
+            if block_size_is_supported(backends, supported_size):
+                return supported_size
+        raise ValueError(f"No common block size for {kv_manager_block_size}. ")
 
-    def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
+    def may_reinitialize_input_batch(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
         """
         Re-initialize the input batch if the block sizes are different from
         `[self.cache_config.block_size]`. This usually happens when there
         are multiple KV cache groups.
 
         Args:
             kv_cache_config: The KV cache configuration.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         """
         block_sizes = [
             kv_cache_group.kv_cache_spec.block_size
             for kv_cache_group in kv_cache_config.kv_cache_groups
             if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
         ]
 
-        # Generate kernel_block_sizes that matches each block_size
-        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
-
         if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
             self.cache_config.block_size
         ]:
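
To make the two cases in select_common_block_size concrete, here is a self-contained sketch of the same selection logic with stub backends; the local MultipleOf, BackendA, and BackendB below are illustrative stand-ins, not vLLM's real classes.

class MultipleOf:
    def __init__(self, base: int):
        self.base = base


class BackendA:  # supports any multiple of 16
    @staticmethod
    def get_supported_kernel_block_size():
        return [MultipleOf(16)]


class BackendB:  # only supports fixed kernel block sizes
    @staticmethod
    def get_supported_kernel_block_size():
        return [64, 128]


def select_common_block_size(kv_manager_block_size, backends):
    def supported(block_size):
        # A size is usable only if every backend accepts it.
        for backend in backends:
            ok = False
            for s in backend.get_supported_kernel_block_size():
                if isinstance(s, int) and block_size == s:
                    ok = True
                elif isinstance(s, MultipleOf) and block_size % s.base == 0:
                    ok = True
            if not ok:
                return False
        return True

    # Case 1: the manager's block size already works for every backend.
    if supported(kv_manager_block_size):
        return kv_manager_block_size
    # Case 2: largest int-format size that divides the manager's block size
    # and is accepted by all backends.
    int_sizes = {
        s
        for b in backends
        for s in b.get_supported_kernel_block_size()
        if isinstance(s, int)
    }
    for s in sorted(int_sizes, reverse=True):
        if kv_manager_block_size % s == 0 and supported(s):
            return s
    raise ValueError(f"No common block size for {kv_manager_block_size}.")


# 256 is not in BackendB's list, so the largest common divisor size, 128, wins.
assert select_common_block_size(256, [BackendA, BackendB]) == 128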
@@ -4354,7 +4348,7 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
             # all backends in the group.
             attn_groups = self.attn_groups[kv_cache_group_id]
             kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
-            selected_kernel_size = self._select_common_block_size(
+            selected_kernel_size = self.select_common_block_size(
                 kv_manager_block_size, attn_groups
             )
             kernel_block_sizes.append(selected_kernel_size)
@@ -4372,6 +4366,7 @@ def _reshape_kv_cache_tensors(
         self,
         kv_cache_config: KVCacheConfig,
         kv_cache_raw_tensors: dict[str, torch.Tensor],
+        kernel_block_sizes: list[int],
     ) -> dict[str, torch.Tensor]:
         """
         Reshape the KV cache tensors to the desired shape and dtype.
@@ -4380,6 +4375,7 @@ def _reshape_kv_cache_tensors(
             kv_cache_config: The KV cache config
             kv_cache_raw_tensors: The KV cache buffer of each layer, with
                 correct size but uninitialized shape.
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4389,6 +4385,10 @@ def _reshape_kv_cache_tensors(
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
+            if group.kv_cache_group_id == len(kernel_block_sizes):
+                # There may be a last group for layers without kv cache.
+                continue
+            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
@@ -4397,24 +4397,21 @@ def _reshape_kv_cache_tensors(
                 num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                 if isinstance(kv_cache_spec, AttentionSpec):
                     has_attn = True
-                    kv_manager_block_size = kv_cache_spec.block_size
-                    kernel_size_list = self._find_compatible_block_sizes(
-                        kv_manager_block_size, attn_backend, return_all=False
+                    num_blocks_per_kv_block = (
+                        kv_cache_spec.block_size // kernel_block_size
                     )
-                    kernel_size = kernel_size_list[0]
-                    num_blocks_per_kv_block = kv_manager_block_size // kernel_size
                     kernel_num_blocks = num_blocks * num_blocks_per_kv_block
 
                     kv_cache_shape = attn_backend.get_kv_cache_shape(
                         kernel_num_blocks,
-                        kernel_size,
+                        kernel_block_size,
                         kv_cache_spec.num_kv_heads,
                         kv_cache_spec.head_size,
                         cache_dtype_str=self.cache_config.cache_dtype,
                     )
                     dtype = kv_cache_spec.dtype
                     try:
-                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
+                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                         assert len(kv_cache_stride_order) == len(kv_cache_shape)
                     except (AttributeError, NotImplementedError):
                         kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
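
The reshape arithmetic above is easier to check with concrete numbers; the values below are made up for illustration only.

kv_manager_block_size = 256   # kv_cache_spec.block_size for this group
kernel_block_size = 64        # kernel block size selected for this group
num_blocks = 100              # blocks as sized by the KV cache manager

num_blocks_per_kv_block = kv_manager_block_size // kernel_block_size  # 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block              # 400

# Same total token capacity, just viewed at a finer block granularity.
assert num_blocks * kv_manager_block_size == kernel_num_blocks * kernel_block_size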
@@ -4497,13 +4494,15 @@ def _update_hybrid_attention_mamba_layout(
         )
 
     def initialize_kv_cache_tensors(
-        self, kv_cache_config: KVCacheConfig
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
     ) -> dict[str, torch.Tensor]:
         """
         Initialize the memory buffer for KV cache.
 
         Args:
             kv_cache_config: The KV cache config
+            kernel_block_sizes: The kernel block sizes for each KV cache group.
+
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
@@ -4512,7 +4511,7 @@ def initialize_kv_cache_tensors(
         kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
         # Change the memory buffer to the desired shape
         kv_caches = self._reshape_kv_cache_tensors(
-            kv_cache_config, kv_cache_raw_tensors
+            kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
         )
 
         # Set up cross-layer KV cache sharing
@@ -4571,9 +4570,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        # The kernel block size for all KV cache groups. For example, if the
+        # kv_cache_manager uses block_size 256 for a given group, but the attention
+        # backends for that group only support block_size 64, we will return
+        # kernel_block_size 64 and split each 256-token block into 4 blocks of
+        # 64 tokens each.
+        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
         # Reinitialization needs to happen after initialize_attn_backend
-        self.may_reinitialize_input_batch(kv_cache_config)
-        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
+        kv_caches = self.initialize_kv_cache_tensors(
+            kv_cache_config, kernel_block_sizes
+        )
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
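
Putting the pieces together, here is a condensed, hypothetical view of the new ordering in initialize_kv_cache; the method names come from the diff above, but the wrapper function itself is illustrative only.

def initialize_kv_cache_flow(runner, kv_cache_config):
    # Attention backends must exist before kernel block sizes can be chosen.
    runner.initialize_attn_backend(kv_cache_config)
    # One kernel block size per KV cache group, e.g. [64] when 256-token
    # manager blocks must be split for a backend that only supports 64.
    kernel_block_sizes = runner._prepare_kernel_block_sizes(kv_cache_config)
    # Both consumers receive the same list, so the input batch's block tables
    # and the reshaped KV cache tensors agree on the kernel granularity.
    runner.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
    return runner.initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes)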