From 582bb7294715b0bd6d7779568b7508183a0c8ed9 Mon Sep 17 00:00:00 2001
From: ltd0924
Date: Thu, 21 Aug 2025 01:40:05 +0800
Subject: [PATCH] [Feature] PD support chunked prefill

---
 fastdeploy/cache_manager/cache_messager.py    | 200 ++++++++++--------
 .../cache_manager/cache_transfer_manager.py   |  11 +
 .../transfer_factory/ipc_cache_transfer.py    |   1 +
 fastdeploy/engine/engine.py                   |  19 +-
 fastdeploy/output/token_processor.py          |   7 +-
 fastdeploy/splitwise/splitwise_connector.py   |   5 +-
 fastdeploy/worker/gpu_model_runner.py         |   4 -
 7 files changed, 137 insertions(+), 110 deletions(-)

diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index 409941f7d8..b19c71c317 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -23,7 +23,8 @@
 import paddle
 
 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
-from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+from fastdeploy.inter_communicator import EngineWorkerQueue
+from fastdeploy.model_executor.ops.gpu import get_output_kv_signal
 from fastdeploy.utils import get_logger
 
 logger = get_logger("cache_messager", "cache_messager.log")
@@ -46,6 +47,7 @@ def __init__(
         nranks,
         num_layers,
         gpu_id=0,
+        block_size=64,
         rdma_port=None,
     ):
         """
@@ -82,6 +84,7 @@ def __init__(
             client_id=self.rank,
             local_data_parallel_id=local_data_parallel_id,
         )
+        self.block_size = block_size
         transfer_protocol = transfer_protocol.split(",")
 
         logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
@@ -143,105 +146,78 @@ def __init__(
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
-
-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
+        self.rank_id = self.rank + local_data_parallel_id * self.nranks
+        self.cache_tasks_list = []  # each element may itself be a list
+        self.engine_cache_task_thread_lock = threading.Lock()
+        self.engine_cache_tasks = [dict() for _ in range(512)]
 
         logger.info(f"cache messager init finished, use {transfer_protocol}")
 
-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
        layerwise_send_cache_thread: send cache to other instance
         """
-        try:
-            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
-            prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
+        while True:
             try:
-                step_shm_value = IPCSignal(
-                    name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
-                    array=prefilled_step_idx_data,
-                    dtype=np.int32,
-                    suffix=self.gpu_id,
-                    create=True,
-                )
-                layer_shm_value = IPCSignal(
-                    name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
-                    array=prefilled_layer_idx_data,
-                    dtype=np.int32,
-                    suffix=self.gpu_id,
-                    create=True,
-                )
-            except:
-                step_shm_value = IPCSignal(
-                    name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
-                    array=prefilled_step_idx_data,
-                    dtype=np.int32,
-                    suffix=self.gpu_id,
-                    create=False,
-                )
-                layer_shm_value = IPCSignal(
-                    name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
-                    array=prefilled_layer_idx_data,
-                    dtype=np.int32,
-                    suffix=self.gpu_id,
-                    create=False,
-                )
-
-            step_shm_value.value[0] = -1
-            layer_shm_value.value[0] = -1
-
-            self.last_step_idx = -1
-            self.last_layer_idx = -1  # int32
-
-            while True:
-                cache_info = self.engine_worker_queue.get_cache_info()
-                if cache_info:
-                    logger.debug(f"cache info {cache_info}")
+                cache_info = self.engine_worker_queue.get_cache_info()
+                if cache_info:
                     for info in cache_info:
                         if info["request_id"] in self.cache_info:
                             self.cache_info[info["request_id"]].update(info)
                             current_info = self.cache_info[info["request_id"]]
                             if "dest_block_ids" in current_info and "src_block_ids" in current_info:
+                                decode_cached_block = len(current_info["src_block_ids"]) - len(
+                                    current_info["dest_block_ids"]
+                                )
                                 current_src_blocks = current_info["src_block_ids"][
                                     -len(current_info["dest_block_ids"]) :
                                 ]
+                                current_info["send_finished_tokens"] = decode_cached_block * self.block_size
+                                current_info["current_tokens"] = current_info["send_finished_tokens"]
                                 current_info["src_block_ids"] = current_src_blocks
-                                current_info["current_layer_ids"] = 0
+                                current_info["current_layer_ids"] = -1
+                                current_info["current_block_num"] = 0
                                 current_info["status"] = "init"
-                                logger.info(f"start cache_infos: {current_info}")
-                                self.cache_info[info["request_id"]] = current_info
-                                self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
+                                logger.info(f"current info: {current_info}")
+                                self.cache_info[info["request_id"]] = current_info
                         else:
                             self.cache_info[info["request_id"]] = info
-                prefilled_layer_idx = layer_shm_value.value[0]
-                prefilled_step_idx = step_shm_value.value[0]
-                if prefilled_layer_idx == self.num_layers - 1:
-                    time.sleep(0.001)
-                    prefilled_layer_idx = layer_shm_value.value[0]
-                    prefilled_step_idx = step_shm_value.value[0]
-
-                if prefilled_step_idx == -1:
-                    time.sleep(0.001)
-                    continue
                 if not self.cache_info:
-                    time.sleep(0.001)
+                    time.sleep(0.005)
                     continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
-                    if "layer_idx" not in item:
-                        item["layer_idx"] = 0
-                    if item["status"] == "error":
-                        del self.cache_info[req_id]
+                    if "prefilled_token_num" not in self.engine_cache_tasks[item["current_id"]]:
                         continue
-                    if item["current_id"] > prefilled_step_idx:
+                    if (
+                        self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"]
+                        <= item["send_finished_tokens"]
+                    ):
+                        time.sleep(0.005)
                         continue
+                    if (
+                        self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"] == item["current_tokens"]
+                        and self.engine_cache_tasks[item["current_id"]]["prefilled_layer_idx"]
+                        == item["current_layer_ids"]
+                    ):
+                        continue
+
+                    prefill_layer = self.engine_cache_tasks[item["current_id"]]["prefilled_layer_idx"]
+                    prefill_tokens = self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"]
+                    if prefill_tokens > item["current_tokens"] and item["current_block_num"] != 0:
+                        prefill_tokens = item["current_tokens"]
+                        prefill_layer = self.num_layers - 1
+                        current_block_num = item["current_block_num"]
+                    else:
+                        current_block_num = (prefill_tokens - item["send_finished_tokens"]) // self.block_size
+                        if prefill_tokens == item["total_tokens"]:
+                            current_block_num = len(item["src_block_ids"])
+
+                    item["current_block_num"] = current_block_num
+
                     current_transfer_protocol = item["transfer_protocol"]
                     if item["transfer_protocol"] == "rdma":
                         target_ip = item["ip"]
@@ -257,14 +233,20 @@ def _prefill_layerwise_send_cache_thread(self):
                     elif item["transfer_protocol"] == "ipc":
                         target_ip = "0.0.0.0"
                         target_id = int(item["device_ids"][self.rank])
-                    src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
-                    dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
-                    if item["current_id"] < prefilled_step_idx:
-                        current_layer_idx = self.num_layers
-                    else:
-                        current_layer_idx = prefilled_layer_idx + 1
-                    for layer_idx in range(item["layer_idx"], current_layer_idx):
+                    src_block_ids = item["src_block_ids"][:current_block_num]
+                    dest_block_ids = item["dest_block_ids"][:current_block_num]
+                    src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
+                    dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
+                    logger.debug(
+                        f"src_block_ids: {src_block_ids.shape}, dest_block_ids: {dest_block_ids.shape}, "
+                        f"req_id: {item['request_id']}, current_tokens: {item['current_tokens']}, prefill tokens {prefill_tokens}, "
+                        f"send_finished_tokens: {item['send_finished_tokens']}, "
+                        f"current_layer_ids: {item['current_layer_ids']}, "
+                        f"prefilled_layer_idx: {prefill_layer}"
+                    )
+
+                    for layer_idx in range(item["current_layer_ids"] + 1, prefill_layer + 1):
                         tic = time.time()
                         return_code = self.messager[current_transfer_protocol].write_cache(
                             target_ip,
@@ -282,8 +264,10 @@ def _prefill_layerwise_send_cache_thread(self):
                                 f"write cache failed, layer_idx: {layer_idx}, "
                                 f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                             )
-                            break
+                            self.engine_cache_tasks[item["current_id"]] = dict()
+                            del self.cache_info[req_id]
+                            break
                         tok = time.time()
                         cost_time = tok - tic
                         block_num = len(src_block_ids)
@@ -295,19 +279,53 @@ def _prefill_layerwise_send_cache_thread(self):
                             f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                             f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                         )
-                        item["layer_idx"] = current_layer_idx
-                        if item["layer_idx"] == self.num_layers:
+                        item["current_layer_ids"] = layer_idx
+                        item["current_tokens"] = prefill_tokens
+                        if item["current_layer_ids"] == self.num_layers - 1:
                             if item["transfer_protocol"] == "ipc":
                                 self.messager["ipc"].write_block_by_sync(target_id)
-                            logger.info(f"finish write cache {item['request_id']}")
-                            self.engine_worker_queue.finish_request_barrier.wait()
-                            if self.rank == 0:
-                                self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
-                                logger.info(f"put write cache {item['request_id']}")
-                            del self.cache_info[req_id]
+                            if prefill_tokens == item["total_tokens"]:
+                                logger.info(f"finish write cache {item['request_id']}")
+                                self.engine_worker_queue.finish_request_barrier.wait()
+                                if self.rank == 0:
+                                    self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
+                                    logger.info(f"put write cache {item['request_id']}")
+                                self.engine_cache_tasks[item["current_id"]] = dict()
+                                del self.cache_info[req_id]
 
-                self.last_step_idx = prefilled_step_idx
-                self.last_layer_idx = prefilled_layer_idx
+                            else:
+                                item["current_layer_ids"] = -1
+                                item["src_block_ids"] = item["src_block_ids"][current_block_num:]
+                                item["dest_block_ids"] = item["dest_block_ids"][current_block_num:]
+                                item["send_finished_tokens"] = prefill_tokens
+                                item["current_block_num"] = 0
+
+            except Exception as e:
+                logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
+                time.sleep(0.01)
+
+    def consume_signals(self):
+        paddle.device.set_device("cpu")
+        kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
+        while True:
+            try:
+                get_output_kv_signal(kv_signal_data, self.rank_id, 0)  # wait_flag
+                if not self.cache_info:
+                    time.sleep(0.01)
+                    continue
+                tasks_count = kv_signal_data[0]
+                if tasks_count == -1:
+                    time.sleep(0.001)
+                    continue
+                layer_id = kv_signal_data[1].numpy().tolist()
+                if layer_id == self.num_layers - 1:
+                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
-        except Exception as e:
-            logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
+                for bi in range(tasks_count):
+                    engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
+                    chunk_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
+                    current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
+                    self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
+                    self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = chunk_token_offset + current_seq_len
+            except Exception as e:
+                logger.error(f"Consume signals get exception: {e}")
diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py
index 5078a513dd..e920888243 100644
--- a/fastdeploy/cache_manager/cache_transfer_manager.py
+++ b/fastdeploy/cache_manager/cache_transfer_manager.py
@@ -18,6 +18,7 @@
 import concurrent.futures
 import json
 import queue
+import threading
 import time
 import traceback
 
@@ -217,9 +218,19 @@ def __init__(self, args):
             rank=self.rank,
             nranks=args.mp_num,
             num_layers=args.num_layers + self.num_extra_layers,
+            block_size=args.block_size,
             gpu_id=self.device,
             rdma_port=args.rdma_port,
         )
+
+        if args.splitwise_role == "prefill":
+            self.send_signal_thread = threading.Thread(target=self.cache_messager.consume_signals, daemon=True)
+            self.send_signal_thread.start()
+            self.layerwise_send_cache_thread = threading.Thread(
+                target=self.cache_messager.prefill_layerwise_send_cache_thread, daemon=True
+            )
+            self.layerwise_send_cache_thread.start()
+
         logger.info("successfully create cache messager")
         logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
 
diff --git a/fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py
index 61a4fa10b0..d97ea88fb1 100644
--- a/fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py
+++ b/fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py
@@ -45,6 +45,7 @@ def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
         self.local_gpu_id = int(local_gpu_id_)
         tmp = paddle.ones([1, 1])
         logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
+        paddle.set_device(f"gpu:{self.local_gpu_id}")
         for layer_id in range(layer_num):
             key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
             value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 3494186fa4..c3038aa945 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -145,8 +145,8 @@ def __init__(self, cfg):
         else:
             self.do_profile = 0
 
-        self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
-        for idx in range(1, self.cfg.max_num_partial_prefills + 1):
+        self.partial_chunked_tokens = [0] * (self.cfg.max_prefill_batch + 1)
+        for idx in range(1, self.cfg.max_prefill_batch + 1):
             self.partial_chunked_tokens[idx] = (
                 (self.cfg.max_num_batched_tokens // idx)
                 // self.cfg.cache_config.block_size
@@ -654,6 +654,7 @@ def update_tokens(idx, chunk_size, update_chunk=False):
 
         for idx in range(len(requests)):
             requests[idx].set("prefill_chunk_info", requests_chunk[idx])
+            requests[idx].min_tokens = len(requests_chunk[idx]) + 1
 
     def update_mm_requests_chunk_size(self, requests):
         """
@@ -804,13 +805,15 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
                 self.split_connector.send_cache_infos(tasks, current_id)
                 if not is_decode:
                    llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+
+                if not self.cfg.enable_mm:
+                    self.update_requests_chunk_size(tasks)
+                else:
+                    self.update_mm_requests_chunk_size(tasks)
                 for task in tasks:
                     task.inference_start_time = time.time()
-                if not is_prefill:
-                    if not self.cfg.enable_mm:
-                        self.update_requests_chunk_size(tasks)
-                    else:
-                        self.update_mm_requests_chunk_size(tasks)
+                if is_prefill:
+                    task.max_tokens = task.min_tokens
                 self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
                 if is_prefill and self.cfg.scheduler_config.name != "splitwise":
                     self.engine_worker_queue.available_prefill_instances.put(1)
@@ -1044,7 +1047,7 @@ def _setting_environ_variables(self):
             )
 
         if self.cfg.splitwise_role != "mixed":
-            variables["FLAGS_use_pd_disaggregation"] = 1
+            variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
             # TODO dynamic load environment variable
             if self.cfg.splitwise_role == "prefill":
                 variables["FLAGS_fmt_write_cache_completed_signal"] = 1
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 36ab0c362b..675dc8e004 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -229,14 +229,13 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False
                     llm_logger.info(f"finished_task_id: {finished_task_id}")
                     self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
                 if task_id in self.prefill_result_status:
+                    if self.prefill_result_status[task_id] != "finished":
+                        result.error_code = 400
+                        result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
                     self.split_connector.send_first_token(task.disaggregate_info, [result])
                     self.resource_manager.stop_flags[index] = True
                     self.resource_manager.tasks_list[index] = None
                     self.resource_manager._recycle_block_tables(task)
-                    if self.prefill_result_status[task_id] != "finished":
-                        result.error_code = 400
-                        result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
-
                     del self.resource_manager.req_dict[task_id]
                     break
                 else:
                     time.sleep(0.002)
diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py
index 8924c00f56..dd2c004fba 100644
--- a/fastdeploy/splitwise/splitwise_connector.py
+++ b/fastdeploy/splitwise/splitwise_connector.py
@@ -373,12 +373,11 @@ def send_cache_infos(self, tasks, current_id):
                 else:
                     addr = "prefill"
 
-                if current_id == -1:
-                    current_id = tasks[i].disaggregate_info["cache_info"]["ipc"]["current_id"]
                 cache_info = {
                     "request_id": tasks[i].request_id,
                     "src_block_ids": tasks[i].block_tables,
-                    "current_id": current_id,
+                    "current_id": tasks[i].idx,
+                    "total_tokens": tasks[i].prompt_token_ids_len + tasks[i].get("seq_lens_decoder", 0),
                 }
                 if addr not in temp_cache_info:
                     temp_cache_info[addr] = []
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index af567cba1e..a62aea096a 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -360,10 +360,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
 
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Set environment variable of prefill node
-        if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
-            os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
-
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
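
Note (reviewer illustration, not part of the patch): with this change the prefill instance ships KV-cache blocks as soon as whole blocks of a chunk are prefilled, instead of waiting for the full prompt. Below is a minimal standalone sketch of that block accounting, assuming the 64-token block size used in the patch; the helper name and the example numbers are hypothetical.

# Illustrative sketch only: how many source blocks become ready to send
# after each chunked-prefill progress signal (mirrors the arithmetic in
# prefill_layerwise_send_cache_thread, but is not the actual implementation).
BLOCK_SIZE = 64


def ready_block_count(prefilled_tokens, sent_tokens, total_tokens, blocks_left):
    """Blocks fully covered by prefilled-but-unsent tokens; flush everything on the last chunk."""
    if prefilled_tokens >= total_tokens:
        # Final chunk: send every remaining block, including a partially filled tail block.
        return blocks_left
    return (prefilled_tokens - sent_tokens) // BLOCK_SIZE


# Example: a 300-token prompt prefilled in 128-token chunks (5 blocks total).
sent, blocks_left = 0, 5
for prefilled in (128, 256, 300):
    n = ready_block_count(prefilled, sent, 300, blocks_left)
    print(f"prefilled={prefilled}: send {n} block(s), layer by layer")
    sent, blocks_left = prefilled, blocks_left - n

Each batch of ready blocks is then written layer by layer (current_layer_ids + 1 through prefilled_layer_idx), which is what lets the decode instance start receiving cache before the whole prompt has finished prefilling.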