Skip to content

Commit 8e392f0

Browse files
authored
[XPU] support prefix cache (#4423)
Co-authored-by: ddchenhao66 <dhaochen163.com>
1 parent 5bde20b commit 8e392f0

File tree

4 files changed

+112
-45
lines changed

4 files changed

+112
-45
lines changed

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,25 @@
3030
from fastdeploy.cache_manager.cache_data import CacheStatus
3131
from fastdeploy.config import SpeculativeConfig
3232
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
33-
from fastdeploy.model_executor.ops.gpu import (
34-
cuda_host_alloc,
35-
cuda_host_free,
36-
set_data_ipc,
37-
share_external_data,
38-
swap_cache_all_layers,
39-
unset_data_ipc,
40-
)
33+
from fastdeploy.platforms import current_platform
34+
35+
if current_platform.is_cuda():
36+
from fastdeploy.model_executor.ops.gpu import (
37+
cuda_host_alloc,
38+
cuda_host_free,
39+
set_data_ipc,
40+
share_external_data,
41+
swap_cache_all_layers,
42+
unset_data_ipc,
43+
)
44+
elif current_platform.is_xpu():
45+
from fastdeploy.model_executor.ops.xpu import (
46+
cuda_host_alloc,
47+
cuda_host_free,
48+
set_data_ipc,
49+
share_external_data,
50+
swap_cache_all_layers,
51+
)
4152
from fastdeploy.utils import get_logger
4253

4354

@@ -114,7 +125,6 @@ def __init__(self, args):
114125
"""
115126
初始化CacheTransferManager
116127
"""
117-
118128
device = args.device_id
119129
rank = args.rank
120130
self.gpu_cache_kvs = {}
@@ -173,8 +183,9 @@ def __init__(self, args):
173183
suffix=args.engine_pid,
174184
create=False,
175185
)
176-
177-
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
186+
# TODO XPU support RL
187+
if not current_platform.is_xpu():
188+
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
178189

179190
def _init_gpu_cache(self, args):
180191

@@ -185,7 +196,10 @@ def _init_gpu_cache(self, args):
185196
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
186197

187198
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
188-
paddle.set_device(f"gpu:{self.device}")
199+
if current_platform.is_cuda():
200+
paddle.set_device(f"gpu:{self.device}")
201+
elif current_platform.is_xpu():
202+
paddle.set_device(f"xpu:{self.device}")
189203
for i in range(args.num_layers + self.num_extra_layers):
190204
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
191205
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
@@ -202,8 +216,12 @@ def _init_gpu_cache(self, args):
202216
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
203217
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
204218
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
205-
key_cache = share_external_data(key_cache, key_name, cache_shape)
206-
val_cache = share_external_data(val_cache, val_name, cache_shape)
219+
if current_platform.is_xpu():
220+
key_cache = share_external_data(key_cache, key_name, cache_shape, True)
221+
val_cache = share_external_data(val_cache, val_name, cache_shape, True)
222+
else:
223+
key_cache = share_external_data(key_cache, key_name, cache_shape)
224+
val_cache = share_external_data(val_cache, val_name, cache_shape)
207225

208226
self.gpu_cache_kvs[key_name] = key_cache
209227
self.gpu_cache_kvs[val_name] = val_cache
@@ -217,9 +235,10 @@ def _init_gpu_cache(self, args):
217235
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
218236
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
219237
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
220-
logger.info(
221-
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
222-
)
238+
if current_platform.is_cuda():
239+
logger.info(
240+
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
241+
)
223242

224243
def _init_cpu_cache(self, args):
225244
if args.num_cpu_blocks == 0:
@@ -473,7 +492,10 @@ def clear_or_update_caches(self, args):
473492
time.sleep(0.1)
474493

475494
# clear gpu caches
476-
paddle.set_device(f"gpu:{self.device}")
495+
if current_platform.is_cuda():
496+
paddle.set_device(f"gpu:{self.device}")
497+
elif current_platform.is_xpu():
498+
paddle.set_device(f"xpu:{self.device}")
477499
for name, tensor in self.gpu_cache_kvs.items():
478500
unset_data_ipc(tensor, name, True, False)
479501
self.gpu_cache_kvs.clear()
@@ -543,5 +565,8 @@ def main():
543565
args = parse_args()
544566
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
545567
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
546-
paddle.set_device(f"gpu:{args.device_id}")
568+
if current_platform.is_cuda():
569+
paddle.set_device(f"gpu:{args.device_id}")
570+
elif current_platform.is_xpu():
571+
paddle.set_device(f"xpu:{args.device_id}")
547572
main()

fastdeploy/engine/args_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def __post_init__(self):
410410
self.enable_prefix_caching = False
411411
if self.speculative_config is not None:
412412
self.enable_prefix_caching = False
413-
if not current_platform.is_cuda():
413+
if not current_platform.is_cuda() and not current_platform.is_xpu():
414414
self.enable_prefix_caching = False
415415
# if self.dynamic_load_weight:
416416
# self.enable_prefix_caching = False

fastdeploy/worker/xpu_model_runner.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from fastdeploy.config import FDConfig
2727
from fastdeploy.engine.request import Request, RequestType
2828
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
29+
from fastdeploy.inter_communicator import IPCSignal
2930
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
3031
from fastdeploy.model_executor.graph_optimization.utils import (
3132
profile_run_guard,
@@ -45,6 +46,8 @@
4546
get_infer_param,
4647
get_padding_offset,
4748
recover_decode_task,
49+
set_data_ipc,
50+
share_external_data,
4851
update_inputs_v1,
4952
)
5053
from fastdeploy.utils import get_logger
@@ -335,11 +338,19 @@ def step_paddle(
335338
class XPUModelRunner(ModelRunnerBase):
336339
""" """
337340

338-
def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
341+
def __init__(
342+
self,
343+
fd_config: FDConfig,
344+
device: str, # logic device
345+
device_id: int, # physical device id
346+
rank: int,
347+
local_rank: int,
348+
):
339349
super().__init__(fd_config=fd_config, device=device)
340350
self.enable_mm = self.model_config.enable_mm
341351
self.rank = rank
342352
self.local_rank = local_rank
353+
self.device_id = device_id
343354
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
344355

345356
# VL model config:
@@ -895,11 +906,11 @@ def initialize_attention_backend(self):
895906
for attn_backend in self.attn_backends:
896907
attn_backend.init_attention_metadata(self.forward_meta)
897908

898-
def initialize_kv_cache(self) -> None:
909+
def initialize_kv_cache(self, profile: bool = False) -> None:
899910
"""
900911
Initialize kv cache
901912
"""
902-
cache_kvs = {}
913+
# cache_kvs = {}
903914
max_block_num = self.num_gpu_blocks
904915

905916
# Get kv cache dtype
@@ -914,21 +925,56 @@ def initialize_kv_cache(self) -> None:
914925

915926
# Get kv cache shape
916927
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
928+
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
929+
930+
cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
931+
cache_ready_signal = IPCSignal(
932+
name="cache_ready_signal",
933+
array=cache_ready_signal_data,
934+
dtype=np.int32,
935+
suffix=self.parallel_config.engine_worker_queue_port,
936+
create=False,
937+
)
938+
939+
# Check if gpu runner needs to create kv cache
940+
# 1. During profiling, it creates its own kv cache.
941+
# 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
942+
create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
943+
if not create_cache_tensor:
944+
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
945+
while cache_ready_signal.value[local_rank] != 1:
946+
time.sleep(1)
947+
logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
948+
949+
logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
950+
cache_kvs_list = []
917951

918952
for i in range(self.model_config.num_hidden_layers):
919-
cache_kvs[f"key_caches_{i}"] = paddle.full(
920-
shape=kv_cache_shape,
921-
fill_value=0,
922-
dtype=cache_type,
923-
)
924-
cache_kvs[f"value_caches_{i}"] = paddle.full(
925-
shape=kv_cache_shape,
926-
fill_value=0,
927-
dtype=cache_type,
928-
)
929-
self.share_inputs["caches"] = list(cache_kvs.values())
930-
for value in cache_kvs.values():
931-
del value
953+
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
954+
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
955+
956+
if create_cache_tensor:
957+
logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
958+
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
959+
set_data_ipc(key_cache, key_cache_name)
960+
val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
961+
set_data_ipc(val_cache, val_cache_name)
962+
cache_kvs_list.extend([key_cache, val_cache])
963+
964+
else:
965+
logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
966+
key_cache = paddle.empty(shape=[], dtype=cache_type)
967+
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape, False)
968+
val_cache = paddle.empty(shape=[], dtype=cache_type)
969+
val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape, False)
970+
cache_kvs_list.extend([key_cache, val_cache])
971+
972+
self.share_inputs["caches"] = cache_kvs_list
973+
974+
if not profile and create_cache_tensor:
975+
cache_ready_signal.value[local_rank] = 1
976+
logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")
977+
932978
paddle.device.xpu.empty_cache()
933979

934980
def initialize_attn_backend(self) -> None:
@@ -1138,18 +1184,12 @@ class at the server level, which is too granular for ModelRunner.
11381184

11391185
return None
11401186

1141-
def prepare_profile(self) -> None:
1142-
"""Prepare the profile run by setting the block number and initializing the KV cache."""
1143-
paddle.device.xpu.empty_cache()
1144-
self.num_gpu_blocks = self.parallel_config.total_block_num
1145-
self.initialize_kv_cache()
1146-
11471187
@profile_run_guard(True)
11481188
def profile_run(self) -> None:
11491189
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
11501190

11511191
self.num_gpu_blocks = self.parallel_config.total_block_num
1152-
self.initialize_kv_cache()
1192+
self.initialize_kv_cache(profile=True)
11531193

11541194
self._dummy_run(
11551195
num_tokens=int(self.scheduler_config.max_num_batched_tokens),

fastdeploy/worker/xpu_worker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from fastdeploy import envs
2424
from fastdeploy.config import FDConfig
2525
from fastdeploy.engine.request import Request
26+
from fastdeploy.platforms import current_platform
2627
from fastdeploy.utils import get_logger, set_random_seed
2728
from fastdeploy.worker.output import ModelRunnerOutput
2829
from fastdeploy.worker.worker_base import WorkerBase
@@ -49,10 +50,11 @@ def __init__(
4950

5051
def init_device(self):
5152
"""Initialize device and Construct model runner"""
53+
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
5254
if paddle.is_compiled_with_xpu():
5355
# Set environment variable
5456
self.device_ids = self.parallel_config.device_ids.split(",")
55-
self.device = f"xpu:{self.local_rank}"
57+
self.device = f"xpu:{self.local_rank % self.max_chips_per_node}"
5658
paddle.device.set_device(self.device)
5759
paddle.set_default_dtype(self.parallel_config.dtype)
5860

@@ -67,6 +69,7 @@ def init_device(self):
6769
fd_config=self.fd_config,
6870
device=self.device,
6971
rank=self.rank,
72+
device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
7073
local_rank=self.local_rank,
7174
)
7275

@@ -109,7 +112,6 @@ def determine_available_memory(self) -> int:
109112
used_memory: {used_memory}, free_memory: {free_memory}"
110113
)
111114

112-
self.model_runner.prepare_profile()
113115
if self.parallel_config.use_ep:
114116
logger.warning("EP mode does not support profile run.")
115117
else:

0 commit comments

Comments
 (0)