@@ -26,6 +26,7 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
+from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
@@ -45,6 +46,8 @@
     get_infer_param,
     get_padding_offset,
     recover_decode_task,
+    set_data_ipc,
+    share_external_data,
     update_inputs_v1,
 )
 from fastdeploy.utils import get_logger
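
A note on the newly imported helpers: set_data_ipc and share_external_data form an exporter/importer pair, and IPCSignal is the shared readiness flag used further down. A minimal sketch of the pairing, assuming the helpers are imported exactly as above; the handle name, the placeholder shape, and the trailing boolean flag are mirrored from the call sites later in this diff, so treat the exact signatures and the shape as assumptions, not documented API:

import paddle

kv_cache_shape = [1024, 8, 64, 128]  # illustrative shape only
handle = "key_caches_0_rank0.device0"

# Producer side: allocate the cache tensor and export its device buffer under the handle.
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, handle)

# Consumer side: start from an empty holder and rebind it to the shared buffer.
holder = paddle.empty(shape=[], dtype="bfloat16")
key_cache_view = share_external_data(holder, handle, kv_cache_shape, False)
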
@@ -335,11 +338,19 @@ def step_paddle(
 class XPUModelRunner(ModelRunnerBase):
     """ """

-    def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        device: str,  # logical device
+        device_id: int,  # physical device id
+        rank: int,
+        local_rank: int,
+    ):
         super().__init__(fd_config=fd_config, device=device)
         self.enable_mm = self.model_config.enable_mm
         self.rank = rank
         self.local_rank = local_rank
+        self.device_id = device_id
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop

         # VL model config:
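
The constructor now distinguishes the logical paddle device from the physical card index that later names the IPC cache handles. A hypothetical call site, where the environment-variable mapping and variable names are illustrative assumptions and fd_config is an FDConfig built elsewhere:

import os

local_rank = 0
rank = 0
# Assumed mapping from logical rank to physical XPU card; not part of this diff.
visible = os.environ.get("XPU_VISIBLE_DEVICES", "0").split(",")
device_id = int(visible[local_rank])

runner = XPUModelRunner(
    fd_config=fd_config,         # FDConfig assumed to be constructed elsewhere
    device=f"xpu:{local_rank}",  # logical device string used by paddle
    device_id=device_id,         # physical card id, used in IPC cache handle names
    rank=rank,
    local_rank=local_rank,
)
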
@@ -895,11 +906,11 @@ def initialize_attention_backend(self):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
-        cache_kvs = {}
+        # cache_kvs = {}
         max_block_num = self.num_gpu_blocks

         # Get kv cache dtype
@@ -914,21 +925,56 @@ def initialize_kv_cache(self) -> None:

         # Get kv cache shape
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+
+        cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
+        cache_ready_signal = IPCSignal(
+            name="cache_ready_signal",
+            array=cache_ready_signal_data,
+            dtype=np.int32,
+            suffix=self.parallel_config.engine_worker_queue_port,
+            create=False,
+        )
+
+        # Check whether the XPU runner needs to create the kv cache tensors itself:
+        # 1. During profiling, it always creates its own kv cache.
+        # 2. Otherwise it creates the kv cache tensors unless P/D disaggregation is enabled.
+        create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
+        if not create_cache_tensor:
+            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
+            while cache_ready_signal.value[local_rank] != 1:
+                time.sleep(1)
+            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
+
+        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
+        cache_kvs_list = []

         for i in range(self.model_config.num_hidden_layers):
-            cache_kvs[f"key_caches_{i}"] = paddle.full(
-                shape=kv_cache_shape,
-                fill_value=0,
-                dtype=cache_type,
-            )
-            cache_kvs[f"value_caches_{i}"] = paddle.full(
-                shape=kv_cache_shape,
-                fill_value=0,
-                dtype=cache_type,
-            )
-        self.share_inputs["caches"] = list(cache_kvs.values())
-        for value in cache_kvs.values():
-            del value
+            key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+
+            if create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(key_cache, key_cache_name)
+                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                set_data_ipc(val_cache, val_cache_name)
+                cache_kvs_list.extend([key_cache, val_cache])
+
+            else:
+                logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape, False)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape, False)
+                cache_kvs_list.extend([key_cache, val_cache])
+
+        self.share_inputs["caches"] = cache_kvs_list
+
+        if not profile and create_cache_tensor:
+            cache_ready_signal.value[local_rank] = 1
+            logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
+
         paddle.device.xpu.empty_cache()

     def initialize_attn_backend(self) -> None:
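
For context, when create_cache_tensor is False the runner blocks until another process flips its slot in cache_ready_signal. A rough sketch of that counterpart, assuming a cache-manager process that opens the same IPCSignal with create=True and exports caches under the same naming scheme; this counterpart is not part of this diff, and the loose variables stand in for its own configuration:

import numpy as np
import paddle

# Hypothetical cache-manager side of the handshake; names mirror this diff.
cache_ready_signal = IPCSignal(
    name="cache_ready_signal",
    array=np.zeros([tensor_parallel_size], dtype=np.int32),
    dtype=np.int32,
    suffix=engine_worker_queue_port,
    create=True,  # the creator allocates the shared signal array
)

for i in range(num_hidden_layers):
    key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
    set_data_ipc(key_cache, f"key_caches_{i}_rank{local_rank}.device{device_id}")
    val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
    set_data_ipc(val_cache, f"value_caches_{i}_rank{local_rank}.device{device_id}")

# Mark this rank's caches ready; the XPU runner's wait loop can now attach.
cache_ready_signal.value[local_rank] = 1
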
@@ -1138,18 +1184,12 @@ class at the server level, which is too granular for ModelRunner.

         return None

-    def prepare_profile(self) -> None:
-        """Prepare the profile run by setting the block number and initializing the KV cache."""
-        paddle.device.xpu.empty_cache()
-        self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
-
     @profile_run_guard(True)
     def profile_run(self) -> None:
         """Execute a forward pass with dummy inputs to profile the memory usage of the model"""

         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         self._dummy_run(
             num_tokens=int(self.scheduler_config.max_num_batched_tokens),
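
With prepare_profile removed, the profiling path allocates its own throwaway caches via profile=True and never touches the ready signal, while the serving path keeps the default and either creates-and-exports (mixed role) or waits-and-attaches (split P/D roles). A compressed view of the two call paths, with runner standing in for an XPUModelRunner instance:

# Profiling: temporary caches, sized from total_block_num inside profile_run.
runner.profile_run()          # internally calls initialize_kv_cache(profile=True)

# Serving: default profile=False; creates and exports caches, or attaches to the
# cache manager's exported caches, depending on splitwise_role.
runner.initialize_kv_cache()
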