Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/online_serving/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,4 @@ DeltaFunctionCall:
- `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset.
- `/v1/resume` - Resume generation.
- `/v1/is_paused` - Check if generation is paused.
- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts.
1 change: 1 addition & 0 deletions docs/online_serving/router.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling,
|----------|------|------|
| POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API |
| POST | `/v1/completions` | Provide scheduling services for general text completion inference requests |
| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts |
| POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling |
| GET | `/registered` | Query the list of currently registered inference instances |
| GET | `/registered_number` | Query the number of currently registered inference instances |
Expand Down
1 change: 1 addition & 0 deletions docs/zh/online_serving/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,4 @@ DeltaFunctionCall:
/v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。
/v1/resume - 恢复推理生成。
/v1/is_paused - 检查推理生成是否已暂停。
/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。
1 change: 1 addition & 0 deletions docs/zh/online_serving/router.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行
|----------|------|------|
| POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 |
| POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 |
| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 |
| POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 |
| GET | `/registered` | 查询当前已注册的推理实例列表 |
| GET | `/registered_number` | 查询当前已注册的推理实例数量 |
Expand Down
135 changes: 135 additions & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
CompletionOutput,
ControlRequest,
ControlResponse,
Request,
RequestMetrics,
RequestOutput,
RequestStatus,
RequestType,
Expand Down Expand Up @@ -1481,6 +1483,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d

return responses

def _control_abort_requests(self, control_req: ControlRequest):
    """
    Abort in-flight inference requests and release their resources.

    Args (read from ``control_req.get_args()``):
        abort_all (bool): if true, abort every request currently known to
            the resource manager or scheduler; ``req_ids`` is ignored.
        req_ids (list[str]): request ids to abort. Each id is also matched
            against its internal ``"<rid>_0"`` suffixed form.

    Returns:
        dict: ``{"aborted": [{"request_id", "output_token_count"}, ...],
        "not_found": [...]}`` where ``output_token_count`` is the number of
        tokens generated before the abort; ``not_found`` lists input ids
        that matched no live request (always empty when ``abort_all``).

    Raises:
        Exception: when ENABLE_V1_KVCACHE_SCHEDULER is off — abort is only
            implemented for the v1 KV-cache scheduler.
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    args = control_req.get_args()
    abort_all = args.get("abort_all", False)
    req_ids = args.get("req_ids", [])
    matched_input_ids = set()
    # Union of requests tracked by the resource manager and the scheduler.
    now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))

    # Step 1: Determine target request list
    if abort_all:
        # all requests in running + waiting
        target_req_ids = now_reqs
    else:
        # Keep only ids that correspond to a live request, accepting either
        # the raw id or its "<rid>_0" internal form.
        target_req_ids = []
        for rid in req_ids:
            if rid in now_reqs:
                target_req_ids.append(rid)
                matched_input_ids.add(rid)
            elif f"{rid}_0" in now_reqs:
                target_req_ids.append(f"{rid}_0")
                matched_input_ids.add(rid)

    if not target_req_ids:
        return {"aborted": [], "not_found": req_ids if not abort_all else []}

    # Step 2: Collect partial results before the requests are torn down.
    aborted_info = []
    results = []
    for req_id in target_req_ids:
        request = self.resource_manager.requests.get(req_id)
        if request is None:
            scheduled_req = self.scheduler.requests.get(req_id)
            if scheduled_req is None:
                continue
            request = scheduled_req.raw

        partial_token_ids = list(request.output_token_ids)

        # Construct a finished response carrying the partial results;
        # fall back to "now" for any metric the request never recorded.
        now = time.time()
        abort_metrics = RequestMetrics(
            arrival_time=request.metrics.arrival_time if request.metrics else now,
            inference_start_time=request.metrics.inference_start_time if request.metrics else now,
            engine_recv_latest_token_time=now,
            engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
            request_start_time=request.metrics.arrival_time if request.metrics else now,
        )
        result = RequestOutput(
            request_id=req_id,
            finished=True,
            outputs=CompletionOutput(
                index=0,
                send_idx=len(partial_token_ids),
                # Emit a single EOS token as the final chunk so the client
                # stream terminates cleanly.
                token_ids=[self.data_processor.eos_token_ids[0]],
            ),
            metrics=abort_metrics,
            error_code=200,
            error_msg="Aborted",
        )
        results.append(result)
        aborted_info.append(
            {
                "request_id": req_id,
                "output_token_count": len(partial_token_ids),
            }
        )

    # Step 3: Execute abort — add all requests to waiting_abort_req_id_set.
    # (ENABLE_V1_KVCACHE_SCHEDULER is guaranteed on by the guard above.)
    for req_id in target_req_ids:
        self.resource_manager.add_abort_req_ids(req_id)
    # Brief yield so the scheduler thread can observe the abort flags
    # before we start waiting — presumably required; TODO confirm.
    time.sleep(0.0001)
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        self._wait_abort_complete(target_req_ids)

    # Add results to scheduler, engine will have a thread calling get_results,
    # then cleanup and call send_response to send to client.
    # When client disconnects, send_response will automatically ignore
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        try:
            self.scheduler.put_results(results)
        except Exception as e:
            # Best-effort delivery: the client may have disconnected; log
            # instead of swallowing silently so failures stay diagnosable.
            self.llm_logger.warning(f"put_results failed while aborting: {e}")

    not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []

    return {"aborted": aborted_info, "not_found": not_found}

def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
    """
    Block until every request in ``target_req_ids`` has left the abort
    pipeline.

    Polls the resource manager; whenever no request leaves the pipeline
    for ``stall_timeout`` seconds, force-recycles the requests still
    stuck in ``to_be_aborted_req_id_set`` (the worker never reported
    them back), resets the progress clock, and resumes polling.
    """
    pending = set(target_req_ids)
    last_count = len(pending)
    progress_ts = time.time()
    # Fast path: nothing from our target set is in the abort pipeline.
    if not (pending & self.resource_manager.get_reqs_in_aborting()):
        return
    while True:
        in_flight = pending & self.resource_manager.get_reqs_in_aborting()
        if not in_flight:
            self.llm_logger.info(f"all {len(pending)} abort reqs cleaned")
            return

        n_left = len(in_flight)
        if n_left < last_count:
            # recycle_abort_task ran for at least one request — progress.
            self.llm_logger.info(f"abort progress: {last_count} -> {n_left}")
            progress_ts = time.time()
            last_count = n_left

        if time.time() - progress_ts > stall_timeout:
            # Stalled: force-recycle only requests the worker never
            # acknowledged (still sitting in to_be_aborted_req_id_set).
            blocked = in_flight & self.resource_manager.to_be_aborted_req_id_set
            if blocked:
                self.llm_logger.warning(
                    f"no abort progress for {stall_timeout}s, "
                    f"force cleanup {len(blocked)} stuck requests (in to_be_aborted)"
                )
                for req_id in list(blocked):
                    self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
                    self.resource_manager.recycle_abort_task(req_id)
                # Reset progress state after the forced cleanup.
                progress_ts = time.time()
                last_count = n_left - len(blocked)
            # else: the rest are in waiting_abort_req_id_set — let them
            # drain through the natural flow.

        time.sleep(0.005)

def _parse_tags(self, control_request: ControlRequest):
"""
Parse tags from control request.
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def recycle_abort_task(self, request_id):
del self.requests[request_id]
del self.req_dict[request_id]
self.to_be_aborted_req_id_set.remove(request_id)
self.update_metrics()

def _trigger_abort(self, request_id, scheduled_reqs):
if request_id in self.requests:
Expand Down Expand Up @@ -1207,6 +1208,9 @@ def download_bos_features(bos_client, features_urls):
return None
inputs["audio_features"] = result

def get_reqs_in_aborting(self):
    """Return the ids of all requests still in the abort pipeline.

    Union of the requests queued for abort (``waiting_abort_req_id_set``)
    and those whose abort was issued but not yet recycled
    (``to_be_aborted_req_id_set``).
    """
    queued = self.waiting_abort_req_id_set
    issued = self.to_be_aborted_req_id_set
    return queued.union(issued)

def get_available_position(self) -> int:
position = 0
while position < self.max_num_seqs:
Expand Down
19 changes: 19 additions & 0 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,25 @@ async def update_weights(request: Request) -> Response:
return control_response.to_api_json_response()


@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """
    Abort in-flight inference requests.

    Body: ``{"abort_all": bool, "req_ids": [str, ...]}`` — the caller must
    supply either ``abort_all=true`` or a non-empty ``req_ids`` list.
    The request is forwarded to the engine as an "abort_requests" control
    method; the engine's JSON response is returned as-is.
    """
    # Reject malformed bodies with a 400 instead of an unhandled 500.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "request body must be valid JSON"})
    if not isinstance(body, dict):
        return JSONResponse(status_code=400, content={"error": "request body must be a JSON object"})

    abort_all = body.get("abort_all", False)
    req_ids = body.get("req_ids", None)

    # Parameter validation
    if not abort_all and not req_ids:
        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
    if req_ids is not None and not isinstance(req_ids, list):
        return JSONResponse(status_code=400, content={"error": "req_ids must be a list of request ids"})

    control_request = ControlRequest(
        request_id=f"control-{uuid.uuid4()}",
        method="abort_requests",
        args={"abort_all": abort_all, "req_ids": req_ids or []},
    )
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


def wrap_streaming_generator(original_generator: AsyncGenerator):
"""
Wrap an async generator to release the connection semaphore when the generator is finished.
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ async def chat_completion_stream_generator(
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"

if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choice.finish_reason = "abort"

inference_start_time[idx] = 0

if request.collect_metrics:
Expand Down Expand Up @@ -802,6 +805,8 @@ async def _create_chat_completion_choice(
if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
finish_reason = "recover_stop"

if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
finish_reason = "abort"
return ChatCompletionResponseChoice(
index=idx,
message=message,
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,8 @@ async def completion_stream_generator(
output,
tool_called[idx],
)
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choices[-1].finish_reason = "abort"
inference_start_time[idx] = 0

send_idx = output.get("send_idx")
Expand Down Expand Up @@ -726,6 +728,8 @@ def request_output_to_completion_response(
output,
False,
)
if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
finish_reason = "abort"

choice_data = CompletionResponseChoice(
token_ids=token_ids,
Expand Down
46 changes: 44 additions & 2 deletions fastdeploy/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse

from fastdeploy.router.utils import (
InstanceInfo,
Expand Down Expand Up @@ -503,6 +503,48 @@ async def health_generate():
return Response(status_code=200)


@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """
    Fan an abort request out to every registered prefill and decode
    instance and aggregate the results.

    The request body is forwarded verbatim to each instance's
    /v1/abort_requests endpoint. Prefill instances receive the abort but
    their responses are ignored; only decode-instance responses are
    aggregated into the returned ``aborted`` / ``not_found`` lists.
    """
    body = await request.json()
    prefill_servers = app.state.router.prefill_servers
    decode_servers = app.state.router.decode_servers
    all_servers = prefill_servers + decode_servers

    async def _post(session, server):
        # Read the payload INSIDE the response context while the session is
        # still open; the original code awaited resp.json() after the
        # ClientSession had closed, which fails and leaks connections.
        async with session.post(f"{server.url()}/v1/abort_requests", json=body) as resp:
            if resp.status == 200:
                return ("ok", await resp.json())
            return ("status", resp.status)

    async with aiohttp.ClientSession() as session:
        tasks = [_post(session, server) for server in all_servers]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

    # Aggregate results from decode instances only.
    all_aborted = []
    all_not_found = []
    errors = []
    decode_start = len(prefill_servers)
    for i, (server, resp) in enumerate(zip(all_servers, responses)):
        if i < decode_start:
            continue
        if isinstance(resp, Exception):
            errors.append({"server": server.url(), "error": str(resp)})
        else:
            kind, payload = resp
            if kind == "ok":
                result = payload.get("result") or {}
                all_aborted.extend(result.get("aborted", []))
                all_not_found.extend(result.get("not_found", []))
            else:
                errors.append({"server": server.url(), "status": payload})

    return JSONResponse(
        content={
            "request_id": f"router-{uuid4()}",
            "status": "success" if not errors else "error",
            "error_message": None if not errors else str(errors),
            "result": {
                "aborted": all_aborted,
                # Deduplicate ids reported missing by multiple instances.
                "not_found": list(set(all_not_found)),
            },
        }
    )


def launch_router(router_args: RouterArgs):
app.state.router_args = router_args
print(f"Starting router with args: {router_args}")
Expand Down
Loading
Loading