diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md index c87e9d51ec5..2b447476020 100644 --- a/docs/online_serving/README.md +++ b/docs/online_serving/README.md @@ -577,3 +577,4 @@ DeltaFunctionCall: - `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset. - `/v1/resume` - Resume generation. - `/v1/is_paused` - Check if generation is paused. +- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts. diff --git a/docs/online_serving/router.md b/docs/online_serving/router.md index cbb0661e9fb..fc973de9a8a 100644 --- a/docs/online_serving/router.md +++ b/docs/online_serving/router.md @@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling, |----------|------|------| | POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API | | POST | `/v1/completions` | Provide scheduling services for general text completion inference requests | +| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts | | POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling | | GET | `/registered` | Query the list of currently registered inference instances | | GET | `/registered_number` | Query the number of currently registered inference instances | diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md index 5c734daeb62..21f16d06e32 100644 --- a/docs/zh/online_serving/README.md +++ b/docs/zh/online_serving/README.md @@ -563,3 +563,4 @@ DeltaFunctionCall: /v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。 /v1/resume - 恢复推理生成。 /v1/is_paused - 检查推理生成是否已暂停。 +/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。 diff --git a/docs/zh/online_serving/router.md b/docs/zh/online_serving/router.md index f806dd64d3e..c5748daa7b8 100644 --- a/docs/zh/online_serving/router.md +++ b/docs/zh/online_serving/router.md @@ -152,6 +152,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行 |----------|------|------| | POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 | | POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 | +| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 | | POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 | | GET | `/registered` | 查询当前已注册的推理实例列表 | | GET | `/registered_number` | 查询当前已注册的推理实例数量 | diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 28776b53ede..f3eaee11cf0 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -43,9 +43,11 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( + CompletionOutput, ControlRequest, ControlResponse, Request, + RequestMetrics, RequestOutput, RequestStatus, RequestType, @@ -1481,6 +1483,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d return responses + def _control_abort_requests(self, control_req: ControlRequest): + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") + args = control_req.get_args() + abort_all = args.get("abort_all", False) + req_ids = args.get("req_ids", []) + matched_input_ids = set() + now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) + + # Step 1: Determine target request list + if abort_all: + # all requests in running + waiting + target_req_ids = now_reqs + else: + # filter out requests that actually exist + target_req_ids = [] + for rid in req_ids: + if rid in now_reqs: + target_req_ids.append(rid) + matched_input_ids.add(rid) + elif f"{rid}_0" in now_reqs: + target_req_ids.append(f"{rid}_0") + matched_input_ids.add(rid) + + if not target_req_ids: + return {"aborted": [], "not_found": req_ids if not abort_all else []} + + # Step 2: Collect partial results + aborted_info = [] + results = [] + for req_id in target_req_ids: + request = self.resource_manager.requests.get(req_id) + if request is None: + scheduled_req = self.scheduler.requests.get(req_id) + if scheduled_req is None: + continue + request = scheduled_req.raw + + partial_token_ids = list(request.output_token_ids) + + # Construct finished response with partial results + now = time.time() + abort_metrics = RequestMetrics( + arrival_time=request.metrics.arrival_time if request.metrics else now, + inference_start_time=request.metrics.inference_start_time if request.metrics else now, + engine_recv_latest_token_time=now, + engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, + request_start_time=request.metrics.arrival_time if request.metrics else now, + ) + result = RequestOutput( + request_id=req_id, + finished=True, + outputs=CompletionOutput( + index=0, + send_idx=len(partial_token_ids), + token_ids=[self.data_processor.eos_token_ids[0]], + ), + metrics=abort_metrics, + error_code=200, + error_msg="Aborted", + ) + results.append(result) + aborted_info.append( + { + "request_id": req_id, + "output_token_count": len(partial_token_ids), + } + ) + + # Step 3: Execute abort — add all requests to waiting_abort_req_id_set + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + for req_id in target_req_ids: + self.resource_manager.add_abort_req_ids(req_id) + time.sleep(0.0001) + if self.cfg.scheduler_config.splitwise_role != "prefill": + self._wait_abort_complete(target_req_ids) + + # Add results to scheduler, engine will have a thread calling get_results, + # then cleanup and call send_response to send to client. + # When client disconnects, send_response will automatically ignore + if self.cfg.scheduler_config.splitwise_role != "prefill": + try: + # self.send_response_server.send_response(req_id, [result]) + self.scheduler.put_results(results) + except Exception: + pass # client may have disconnected + + not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] + + return {"aborted": aborted_info, "not_found": not_found} + + def _wait_abort_complete(self, target_req_ids, stall_timeout=1): + """ + Wait for all abort requests to complete. + - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet + - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, + reset progress state if any, then continue monitoring + """ + target_set = set(target_req_ids) + prev_remaining_count = len(target_set) + last_progress_time = time.time() + remaining = target_set & self.resource_manager.get_reqs_in_aborting() + while remaining: + remaining = target_set & self.resource_manager.get_reqs_in_aborting() + if not remaining: + self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") + return + + current_count = len(remaining) + if current_count < prev_remaining_count: + # progress made: recycle_abort_task was called + self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") + last_progress_time = time.time() + prev_remaining_count = current_count + + if time.time() - last_progress_time > stall_timeout: + # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) + stuck = remaining & self.resource_manager.to_be_aborted_req_id_set + if stuck: + self.llm_logger.warning( + f"no abort progress for {stall_timeout}s, " + f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" + ) + for req_id in list(stuck): + self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") + self.resource_manager.recycle_abort_task(req_id) + # reset progress state + last_progress_time = time.time() + prev_remaining_count = current_count - len(stuck) + # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow + + time.sleep(0.005) + def _parse_tags(self, control_request: ControlRequest): """ Parse tags from control request. diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 0d91ea4d8bc..632c6672345 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -282,6 +282,7 @@ def recycle_abort_task(self, request_id): del self.requests[request_id] del self.req_dict[request_id] self.to_be_aborted_req_id_set.remove(request_id) + self.update_metrics() def _trigger_abort(self, request_id, scheduled_reqs): if request_id in self.requests: @@ -1207,6 +1208,9 @@ def download_bos_features(bos_client, features_urls): return None inputs["audio_features"] = result + def get_reqs_in_aborting(self): + return self.waiting_abort_req_id_set | self.to_be_aborted_req_id_set + def get_available_position(self) -> int: position = 0 while position < self.max_num_seqs: diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index a48850e2958..fcc57c86dff 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -480,6 +480,25 @@ async def update_weights(request: Request) -> Response: return control_response.to_api_json_response() +@app.post("/v1/abort_requests") +async def abort_requests(request: Request): + body = await request.json() + abort_all = body.get("abort_all", False) + req_ids = body.get("req_ids", None) + + # 参数校验 + if not abort_all and not req_ids: + return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"}) + + control_request = ControlRequest( + request_id=f"control-{uuid.uuid4()}", + method="abort_requests", + args={"abort_all": abort_all, "req_ids": req_ids or []}, + ) + control_response = await app.state.engine_client.run_control_method(control_request) + return control_response.to_api_json_response() + + def wrap_streaming_generator(original_generator: AsyncGenerator): """ Wrap an async generator to release the connection semaphore when the generator is finished. diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 9d380b0db0c..d4fd32761a3 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -469,6 +469,9 @@ async def chat_completion_stream_generator( if res.get("error_msg") is not None and "Recover" in res["error_msg"]: choice.finish_reason = "recover_stop" + if res.get("error_msg") is not None and "Aborted" in res["error_msg"]: + choice.finish_reason = "abort" + inference_start_time[idx] = 0 if request.collect_metrics: @@ -802,6 +805,8 @@ async def _create_chat_completion_choice( if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: finish_reason = "recover_stop" + if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: + finish_reason = "abort" return ChatCompletionResponseChoice( index=idx, message=message, diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 4caf9fe210a..7c7ad1ae265 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -586,6 +586,8 @@ async def completion_stream_generator( output, tool_called[idx], ) + if res.get("error_msg") is not None and "Aborted" in res["error_msg"]: + choices[-1].finish_reason = "abort" inference_start_time[idx] = 0 send_idx = output.get("send_idx") @@ -726,6 +728,8 @@ def request_output_to_completion_response( output, False, ) + if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]: + finish_reason = "abort" choice_data = CompletionResponseChoice( token_ids=token_ids, diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index d64542b6ccc..960a64e7f58 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -17,8 +17,8 @@ import aiohttp import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse from fastdeploy.router.utils import ( InstanceInfo, @@ -503,6 +503,48 @@ async def health_generate(): return Response(status_code=200) +@app.post("/v1/abort_requests") +async def abort_requests(request: Request): + body = await request.json() + prefill_servers = app.state.router.prefill_servers + decode_servers = app.state.router.decode_servers + all_servers = prefill_servers + decode_servers + + async with aiohttp.ClientSession() as session: + tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Aggregate results from Node D only + all_aborted = [] + all_not_found = [] + errors = [] + decode_start = len(prefill_servers) + for i, (server, resp) in enumerate(zip(all_servers, responses)): + if i < decode_start: + continue + if isinstance(resp, Exception): + errors.append({"server": server.url(), "error": str(resp)}) + elif resp.status == 200: + data = await resp.json() + result = data.get("result") or {} + all_aborted.extend(result.get("aborted", [])) + all_not_found.extend(result.get("not_found", [])) + else: + errors.append({"server": server.url(), "status": resp.status}) + + return JSONResponse( + content={ + "request_id": f"router-{uuid4()}", + "status": "success" if not errors else "error", + "error_message": None if not errors else str(errors), + "result": { + "aborted": all_aborted, + "not_found": list(set(all_not_found)), + }, + } + ) + + def launch_router(router_args: RouterArgs): app.state.router_args = router_args print(f"Starting router with args: {router_args}") diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 5a6241c4433..5132953023a 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -3504,3 +3504,215 @@ def _fake_sleep(s): # At least one sleep call was made, confirming the inner function executed self.assertGreaterEqual(call_count[0], 1) self._detach_finalizer(eng) + + # ── _control_abort_requests / _wait_abort_complete ─────────────── + + def _make_abort_engine(self, splitwise_role="mixed"): + """Create an engine wired up for abort tests.""" + extra = {} + if splitwise_role != "mixed": + extra["router"] = "0.0.0.0:9000" + cfg = self._make_cfg(splitwise_role=splitwise_role, num_gpu_blocks_override=4, **extra) + eng = self._make_engine(cfg) + eng.llm_logger = MagicMock() + + # data_processor with eos token + eng.data_processor = MagicMock() + eng.data_processor.eos_token_ids = [2] + + # resource_manager with requests dict and abort sets + eng.resource_manager = MagicMock() + eng.resource_manager.requests = {} + eng.resource_manager.waiting_abort_req_id_set = set() + eng.resource_manager.to_be_aborted_req_id_set = set() + eng.resource_manager.get_reqs_in_aborting = lambda: ( + eng.resource_manager.waiting_abort_req_id_set | eng.resource_manager.to_be_aborted_req_id_set + ) + + # scheduler with requests dict and put_results + eng.scheduler = MagicMock() + eng.scheduler.requests = {} + eng.scheduler.put_results = MagicMock() + + return eng + + def _make_fake_request(self, output_token_ids=None): + """Create a fake request object for abort tests.""" + req = MagicMock() + req.output_token_ids = output_token_ids or [10, 20, 30] + req.metrics = MagicMock() + req.metrics.arrival_time = 1000.0 + req.metrics.inference_start_time = 1000.1 + req.metrics.engine_recv_first_token_time = 1000.2 + return req + + def test_control_abort_requests_not_v1_raises(self): + """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off.""" + eng = self._make_abort_engine() + control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): + with self.assertRaises(Exception) as ctx: + eng._control_abort_requests(control_req) + self.assertIn("only supported", str(ctx.exception)) + self._detach_finalizer(eng) + + def test_control_abort_requests_abort_all(self): + """abort_all=True aborts all requests in resource_manager + scheduler.""" + eng = self._make_abort_engine() + eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])} + eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))} + + control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) + + def clear_abort_sets(req_id): + # Simulate immediate abort completion + eng.resource_manager.waiting_abort_req_id_set.discard(req_id) + + eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) + + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): + result = eng._control_abort_requests(control_req) + + self.assertEqual(len(result["aborted"]), 2) + self.assertEqual(result["not_found"], []) + ids = {a["request_id"] for a in result["aborted"]} + self.assertEqual(ids, {"req-1_0", "req-2_0"}) + # put_results should have been called (not prefill) + eng.scheduler.put_results.assert_called_once() + self._detach_finalizer(eng) + + def test_control_abort_requests_by_req_ids_with_suffix_match(self): + """req_ids match both exact and _0 suffix.""" + eng = self._make_abort_engine() + eng.resource_manager.requests = { + "req-A_0": self._make_fake_request([1, 2, 3]), + "req-B": self._make_fake_request([4, 5]), + } + + control_req = ControlRequest( + "ctrl-1", + "abort_requests", + { + "abort_all": False, + "req_ids": ["req-A", "req-B", "req-C"], + }, + ) + + def clear_abort_sets(req_id): + eng.resource_manager.waiting_abort_req_id_set.discard(req_id) + + eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) + + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): + result = eng._control_abort_requests(control_req) + + aborted_ids = {a["request_id"] for a in result["aborted"]} + self.assertIn("req-A_0", aborted_ids) # matched via _0 suffix + self.assertIn("req-B", aborted_ids) # exact match + self.assertEqual(result["not_found"], ["req-C"]) + self._detach_finalizer(eng) + + def test_control_abort_requests_no_match(self): + """No requests found returns empty aborted and all in not_found.""" + eng = self._make_abort_engine() + control_req = ControlRequest( + "ctrl-1", + "abort_requests", + { + "abort_all": False, + "req_ids": ["nonexistent"], + }, + ) + + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): + result = eng._control_abort_requests(control_req) + + self.assertEqual(result["aborted"], []) + self.assertEqual(result["not_found"], ["nonexistent"]) + self._detach_finalizer(eng) + + def test_control_abort_requests_prefill_skips_wait_and_put(self): + """Prefill role skips _wait_abort_complete and put_results.""" + eng = self._make_abort_engine(splitwise_role="prefill") + eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} + + control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) + eng.resource_manager.add_abort_req_ids = MagicMock() + + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): + result = eng._control_abort_requests(control_req) + + self.assertEqual(len(result["aborted"]), 1) + eng.scheduler.put_results.assert_not_called() + self._detach_finalizer(eng) + + def test_control_abort_requests_output_token_count(self): + """output_token_count reflects partial_token_ids length.""" + eng = self._make_abort_engine() + eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])} + + control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) + + def clear_abort_sets(req_id): + eng.resource_manager.waiting_abort_req_id_set.discard(req_id) + + eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) + + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): + result = eng._control_abort_requests(control_req) + + self.assertEqual(result["aborted"][0]["output_token_count"], 5) + self._detach_finalizer(eng) + + def test_wait_abort_complete_immediate(self): + """_wait_abort_complete returns immediately when all requests already cleaned.""" + eng = self._make_abort_engine() + # Empty abort sets → remaining is empty → returns immediately + eng._wait_abort_complete(["req-1_0"]) + self._detach_finalizer(eng) + + def test_wait_abort_complete_progress(self): + """_wait_abort_complete exits when background thread cleans up.""" + eng = self._make_abort_engine() + eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"} + + call_count = [0] + + def fake_sleep(s): + call_count[0] += 1 + # Simulate background thread cleaning up after first sleep + eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0") + + with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep): + eng._wait_abort_complete(["req-1_0"]) + + self.assertGreaterEqual(call_count[0], 1) + self._detach_finalizer(eng) + + def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self): + """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set.""" + eng = self._make_abort_engine() + eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"} + + def mock_recycle(req_id): + eng.resource_manager.to_be_aborted_req_id_set.discard(req_id) + + eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle) + + # Make time.time() advance past stall_timeout + time_values = [100.0, 100.0, 102.0, 102.0, 102.0] + time_idx = [0] + + def fake_time(): + idx = min(time_idx[0], len(time_values) - 1) + time_idx[0] += 1 + return time_values[idx] + + with ( + patch("fastdeploy.engine.common_engine.time.time", fake_time), + patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None), + ): + eng._wait_abort_complete(["req-1_0"], stall_timeout=1) + + eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0") + self._detach_finalizer(eng) diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 0cd57421701..8136dd3035c 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -809,3 +809,80 @@ def test_config_info(): api_server = _reload_api_server(args) api_server.llm_engine = None assert api_server.config_info().status_code == 500 + + +# ── /v1/abort_requests ────────────────────────────────────────────── + + +def _mock_abort_control_response(api_server, result, status_code=200): + mock_resp = MagicMock() + mock_resp.to_api_json_response.return_value = api_server.JSONResponse( + content={"request_id": "control-test", "status": "success", "error_message": None, "result": result}, + status_code=status_code, + ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_resp) + + +@pytest.mark.asyncio +async def test_abort_requests_with_req_ids(): + args = _build_args() + api_server = _reload_api_server(args) + _mock_abort_control_response( + api_server, + { + "aborted": [{"request_id": "req-1_0", "output_token_count": 10}], + "not_found": ["req-999"], + }, + ) + req = MagicMock() + req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]}) + resp = await api_server.abort_requests(req) + assert resp.status_code == 200 + control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] + assert control_req.method == "abort_requests" + assert control_req.args["req_ids"] == ["req-1", "req-999"] + assert control_req.args["abort_all"] is False + + +@pytest.mark.asyncio +async def test_abort_requests_with_abort_all(): + args = _build_args() + api_server = _reload_api_server(args) + _mock_abort_control_response( + api_server, + { + "aborted": [ + {"request_id": "req-1_0", "output_token_count": 5}, + {"request_id": "req-2_0", "output_token_count": 12}, + ], + "not_found": [], + }, + ) + req = MagicMock() + req.json = AsyncMock(return_value={"abort_all": True}) + resp = await api_server.abort_requests(req) + assert resp.status_code == 200 + control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] + assert control_req.args["abort_all"] is True + assert control_req.args["req_ids"] == [] + + +@pytest.mark.asyncio +async def test_abort_requests_missing_params(): + args = _build_args() + api_server = _reload_api_server(args) + req = MagicMock() + req.json = AsyncMock(return_value={}) + resp = await api_server.abort_requests(req) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_abort_requests_empty_req_ids(): + args = _build_args() + api_server = _reload_api_server(args) + req = MagicMock() + req.json = AsyncMock(return_value={"req_ids": []}) + resp = await api_server.abort_requests(req) + assert resp.status_code == 400 diff --git a/tests/router/test_router.py b/tests/router/test_router.py index 4b9476883fb..aa5be52f2f1 100644 --- a/tests/router/test_router.py +++ b/tests/router/test_router.py @@ -22,7 +22,7 @@ import unittest from types import SimpleNamespace -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from fastdeploy.router.router import Router, RouterArgs @@ -144,5 +144,170 @@ async def test_registered_number(self, mock_health): self.assertEqual(result["decode"], 0) +class TestRouterAbortRequests(unittest.IsolatedAsyncioTestCase): + """Tests for /v1/abort_requests route in router.py.""" + + def _make_mock_session(self, responses): + """Create a mock aiohttp.ClientSession where post() returns coroutines.""" + mock_session = MagicMock() + call_count = 0 + + def post_side_effect(*args, **kwargs): + nonlocal call_count + resp = responses[call_count] + call_count += 1 + if isinstance(resp, Exception): + raise resp + + async def _coro(): + return resp + + return _coro() + + mock_session.post = MagicMock(side_effect=post_side_effect) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + return mock_session + + @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) + async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health): + """P and D both receive the request, but only D results are aggregated.""" + from fastdeploy.router.router import abort_requests as abort_fn + from fastdeploy.router.router import app + + router = Router(_make_args(splitwise=True)) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill")) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode")) + app.state.router = router + + prefill_resp = AsyncMock() + prefill_resp.status = 200 + prefill_resp.json = AsyncMock( + return_value={ + "request_id": "control-p", + "status": "success", + "error_message": None, + "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []}, + } + ) + decode_resp = AsyncMock() + decode_resp.status = 200 + decode_resp.json = AsyncMock( + return_value={ + "request_id": "control-d", + "status": "success", + "error_message": None, + "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []}, + } + ) + + mock_session = self._make_mock_session([prefill_resp, decode_resp]) + mock_request = AsyncMock() + mock_request.json = AsyncMock(return_value={"req_ids": ["req-1"]}) + + with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): + resp = await abort_fn(mock_request) + + import json + + body = json.loads(resp.body) + self.assertEqual(len(body["result"]["aborted"]), 1) + self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15) + self.assertEqual(body["status"], "success") + self.assertEqual(mock_session.post.call_count, 2) + + @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) + async def test_abort_decode_error_returns_error_status(self, mock_health): + """When D node returns a non-200 status, status should be 'error'.""" + from fastdeploy.router.router import abort_requests as abort_fn + from fastdeploy.router.router import app + + router = Router(_make_args(splitwise=True)) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill")) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode")) + app.state.router = router + + prefill_resp = AsyncMock() + prefill_resp.status = 200 + prefill_resp.json = AsyncMock( + return_value={ + "request_id": "control-p", + "status": "success", + "error_message": None, + "result": {"aborted": [], "not_found": []}, + } + ) + decode_resp = AsyncMock() + decode_resp.status = 500 + + mock_session = self._make_mock_session([prefill_resp, decode_resp]) + mock_request = AsyncMock() + mock_request.json = AsyncMock(return_value={"abort_all": True}) + + with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): + resp = await abort_fn(mock_request) + + import json + + body = json.loads(resp.body) + self.assertEqual(body["status"], "error") + self.assertIsNotNone(body["error_message"]) + + @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) + async def test_abort_decode_exception_returns_error(self, mock_health): + """When D node connection fails (exception), error should be captured.""" + from fastdeploy.router.router import abort_requests as abort_fn + from fastdeploy.router.router import app + + router = Router(_make_args(splitwise=True)) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill")) + await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode")) + app.state.router = router + + prefill_resp = AsyncMock() + prefill_resp.status = 200 + prefill_resp.json = AsyncMock( + return_value={ + "request_id": "control-p", + "status": "success", + "error_message": None, + "result": {"aborted": [], "not_found": []}, + } + ) + + # D node raises exception — but asyncio.gather(return_exceptions=True) captures it + # So we pass the exception as a response directly + mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder + call_idx = [0] + + def post_with_exception(*args, **kwargs): + call_idx[0] += 1 + if call_idx[0] == 1: + # prefill: normal + async def _coro(): + return prefill_resp + + return _coro() + else: + # decode: raise (gather with return_exceptions=True will catch) + async def _coro_err(): + raise ConnectionError("refused") + + return _coro_err() + + mock_session.post = MagicMock(side_effect=post_with_exception) + mock_request = AsyncMock() + mock_request.json = AsyncMock(return_value={"abort_all": True}) + + with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): + resp = await abort_fn(mock_request) + + import json + + body = json.loads(resp.body) + self.assertEqual(body["status"], "error") + self.assertIn("refused", body["error_message"]) + + if __name__ == "__main__": unittest.main()