[Feature] PD support chunked prefill #3503

Open · wants to merge 3 commits into base: develop
200 changes: 109 additions & 91 deletions fastdeploy/cache_manager/cache_messager.py
@@ -23,7 +23,8 @@
import paddle

from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal
from fastdeploy.utils import get_logger

logger = get_logger("cache_messager", "cache_messager.log")
@@ -46,6 +47,7 @@ def __init__(
nranks,
num_layers,
gpu_id=0,
block_size=64,
rdma_port=None,
):
"""
@@ -82,6 +84,7 @@ def __init__(
client_id=self.rank,
local_data_parallel_id=local_data_parallel_id,
)
self.block_size = block_size
transfer_protocol = transfer_protocol.split(",")

logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
@@ -143,105 +146,78 @@ def __init__(

self.gpu_id = gpu_id
self.cache_info = dict()
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks

layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start()
self.rank_id = self.rank + local_data_parallel_id * self.nranks
self.cache_tasks_list = [] # each element may itself be a list
self.engine_cache_task_thread_lock = threading.Lock()
self.engine_cache_tasks = [dict() for _ in range(512)]

logger.info(f"cache messager init finished, use {transfer_protocol}")

def _prefill_layerwise_send_cache_thread(self):
def prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
"""
try:
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
while True:
try:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True,
)
except:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False,
)

step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1

self.last_step_idx = -1
self.last_layer_idx = -1 # int32

while True:

cache_info = self.engine_worker_queue.get_cache_info()

if cache_info:
logger.debug(f"cache info {cache_info}")
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
if "dest_block_ids" in current_info and "src_block_ids" in current_info:
decode_cached_block = len(current_info["src_block_ids"]) - len(
current_info["dest_block_ids"]
)
current_src_blocks = current_info["src_block_ids"][
-len(current_info["dest_block_ids"]) :
]
current_info["send_finished_tokens"] = decode_cached_block * self.block_size
current_info["current_tokens"] = current_info["send_finished_tokens"]
current_info["src_block_ids"] = current_src_blocks
current_info["current_layer_ids"] = 0
current_info["current_layer_ids"] = -1
current_info["current_block_num"] = 0
current_info["status"] = "init"
logger.info(f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
logger.info(f"current info: {current_info}")
self.cache_info[info["request_id"]] = current_info
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
if prefilled_layer_idx == self.num_layers - 1:
time.sleep(0.001)
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]

if prefilled_step_idx == -1:
time.sleep(0.001)
continue
if not self.cache_info:
time.sleep(0.001)
time.sleep(0.005)
continue
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")

for req_id, item in list(self.cache_info.items()):
if "status" not in item:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item["status"] == "error":
del self.cache_info[req_id]
if "prefilled_token_num" not in self.engine_cache_tasks[item["current_id"]]:
continue
if item["current_id"] > prefilled_step_idx:
if (
self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"]
<= item["send_finished_tokens"]
):
time.sleep(0.005)
continue
if (
self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"] == item["current_tokens"]
and self.engine_cache_tasks[item["current_id"]]["prefilled_layer_idx"]
== item["current_layer_ids"]
):
continue

prefill_layer = self.engine_cache_tasks[item["current_id"]]["prefilled_layer_idx"]
prefill_tokens = self.engine_cache_tasks[item["current_id"]]["prefilled_token_num"]
if prefill_tokens > item["current_tokens"] and item["current_block_num"] != 0:
prefill_tokens = item["current_tokens"]
prefill_layer = self.num_layers - 1
current_block_num = item["current_block_num"]
else:
current_block_num = (prefill_tokens - item["send_finished_tokens"]) // self.block_size
if prefill_tokens == item["total_tokens"]:
current_block_num = len(item["src_block_ids"])

item["current_block_num"] = current_block_num

current_transfer_protocol = item["transfer_protocol"]
if item["transfer_protocol"] == "rdma":
target_ip = item["ip"]
@@ -257,14 +233,20 @@ def _prefill_layerwise_send_cache_thread(self):
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item["device_ids"][self.rank])
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
if item["current_id"] < prefilled_step_idx:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1

for layer_idx in range(item["layer_idx"], current_layer_idx):
src_block_ids = item["src_block_ids"][:current_block_num]
dest_block_ids = item["dest_block_ids"][:current_block_num]
src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")
logger.debug(
f"src_block_ids: {src_block_ids.shape}, dest_block_ids: {dest_block_ids.shape}"
f"req_id: {item['request_id']}, current_tokens: {item['current_tokens']}, prefill tokens {prefill_tokens}"
f"send_finished_tokens: {item['send_finished_tokens']}, "
f"current_layer_ids: {item['current_layer_ids']}, "
f"prefilled_layer_idx: {prefill_layer}"
)

for layer_idx in range(item["current_layer_ids"] + 1, prefill_layer + 1):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
@@ -282,8 +264,10 @@ def _prefill_layerwise_send_cache_thread(self):
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
break
self.engine_cache_tasks[item["current_id"]] = dict()
del self.cache_info[req_id]

break
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
@@ -295,19 +279,53 @@ def _prefill_layerwise_send_cache_thread(self):
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
item["current_layer_ids"] = layer_idx
item["current_tokens"] = prefill_tokens
if item["current_layer_ids"] == self.num_layers - 1:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id]
if prefill_tokens == item["total_tokens"]:
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
self.engine_cache_tasks[item["current_id"]] = dict()
del self.cache_info[req_id]

self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
else:
item["current_layer_ids"] = -1
item["src_block_ids"] = item["src_block_ids"][current_block_num:]
item["dest_block_ids"] = item["dest_block_ids"][current_block_num:]
item["send_finished_tokens"] = prefill_tokens
item["current_block_num"] = 0

except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
time.sleep(0.01)

def consume_signals(self):
paddle.device.set_device("cpu")
kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
while True:
try:
get_output_kv_signal(kv_signal_data, self.rank_id, 0) # wait_flag
if not self.cache_info:
time.sleep(0.01)
continue
tasks_count = kv_signal_data[0]
if tasks_count == -1:
time.sleep(0.001)
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")

except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
chuck_token_offset = kv_signal_data[3 * bi + 3].numpy().tolist()
current_seq_len = kv_signal_data[3 * bi + 4].numpy().tolist()
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = chuck_token_offset + current_seq_len
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
11 changes: 11 additions & 0 deletions fastdeploy/cache_manager/cache_transfer_manager.py
@@ -18,6 +18,7 @@
import concurrent.futures
import json
import queue
import threading
import time
import traceback

@@ -217,9 +218,19 @@ def __init__(self, args):
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
block_size=args.block_size,
gpu_id=self.device,
rdma_port=args.rdma_port,
)

if args.splitwise_role == "prefill":
self.send_signal_thread = threading.Thread(target=self.cache_messager.consume_signals, daemon=True)
self.send_signal_thread.start()
self.layerwise_send_cache_thread = threading.Thread(
target=self.cache_messager.prefill_layerwise_send_cache_thread, daemon=True
)
self.layerwise_send_cache_thread.start()

logger.info("successfully create cache messager")
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")

Expand Down
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
self.local_gpu_id = int(local_gpu_id_)
tmp = paddle.ones([1, 1])
logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
paddle.set_device(f"gpu:{self.local_gpu_id}")
for layer_id in range(layer_num):
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
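
The device-pinning line added above matters because the per-layer IPC cache handles are created in this constructor. A minimal illustration of the pattern, assuming a GPU build of paddle (only `paddle.set_device` comes from the diff; the rest is a hypothetical stand-in for the per-layer handle setup):

```python
import paddle

def open_layer_buffers(layer_num: int, local_gpu_id: int):
    # Pin this worker to its own GPU before touching any buffers, so every
    # per-layer handle below is created on gpu:{local_gpu_id} rather than on
    # whatever device paddle last selected.
    paddle.set_device(f"gpu:{local_gpu_id}")
    return [paddle.zeros([1, 1]) for _ in range(layer_num)]  # placeholder buffers
```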
19 changes: 11 additions & 8 deletions fastdeploy/engine/engine.py
@@ -145,8 +145,8 @@ def __init__(self, cfg):
else:
self.do_profile = 0

self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens = [0] * (self.cfg.max_prefill_batch + 1)
for idx in range(1, self.cfg.max_prefill_batch + 1):
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
@@ -654,6 +654,7 @@ def update_tokens(idx, chunk_size, update_chunk=False):

for idx in range(len(requests)):
requests[idx].set("prefill_chunk_info", requests_chunk[idx])
requests[idx].min_tokens = len(requests_chunk[idx]) + 1

def update_mm_requests_chunk_size(self, requests):
"""
@@ -804,13 +805,15 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
self.split_connector.send_cache_infos(tasks, current_id)
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")

if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
for task in tasks:
task.inference_start_time = time.time()
if not is_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
if is_prefill:
task.max_tokens = task.min_tokens
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1)
@@ -1044,7 +1047,7 @@ def _setting_environ_variables(self):
)

if self.cfg.splitwise_role != "mixed":
variables["FLAGS_use_pd_disaggregation"] = 1
variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
# TODO dynamic load environment variable
if self.cfg.splitwise_role == "prefill":
variables["FLAGS_fmt_write_cache_completed_signal"] = 1
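
As a quick sanity check on the chunk-budget loop changed above, here is the computation in isolation. The trailing `* block_size` rounding is assumed from the visible `// block_size` division, and the numbers 8192/64/3 are made-up illustrations, not defaults asserted by this PR:

```python
def partial_chunked_tokens(max_num_batched_tokens: int, block_size: int, max_prefill_batch: int):
    """Per-concurrency token budget for one prefill chunk, rounded down to whole KV blocks."""
    budget = [0] * (max_prefill_batch + 1)
    for idx in range(1, max_prefill_batch + 1):
        budget[idx] = (max_num_batched_tokens // idx) // block_size * block_size
    return budget

# 8192 batched tokens, 64-token blocks:
#   1 concurrent prefill  -> 8192 tokens per chunk
#   2 concurrent prefills -> 4096 tokens per chunk
#   3 concurrent prefills -> 2688 tokens per chunk (8192 // 3 = 2730 -> 42 whole blocks)
print(partial_chunked_tokens(8192, 64, 3))  # [0, 8192, 4096, 2688]
```

With `prefill_chunk_info` attached, the prefill side also sets `min_tokens` to the number of chunks plus one and pins `max_tokens` to the same value, so a prefill instance stops right after the final chunk has produced its token.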
7 changes: 3 additions & 4 deletions fastdeploy/output/token_processor.py
@@ -229,14 +229,13 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False
llm_logger.info(f"finished_task_id: {finished_task_id}")
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
if task_id in self.prefill_result_status:
if self.prefill_result_status[task_id] != "finished":
result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
self.split_connector.send_first_token(task.disaggregate_info, [result])
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
if self.prefill_result_status[task_id] != "finished":
result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
del self.resource_manager.req_dict[task_id]
break
else:
time.sleep(0.002)