Skip to content

Commit 4215028

Browse files
guozhihao-224 authored and claude committed
refactor(experimental): reuse HTTP clients, add response models, and parallelize ops in inference service
Replace per-request httpx/requests client creation with shared long-lived
clients across the inference service stack (controller, gateway, router,
data proxy, InfBridge). This eliminates repeated TCP connection setup and
TLS handshake overhead on every API call.

Key changes:
- Controller: shared httpx.Client/AsyncClient, idempotent destroy()
- Gateway: shared AsyncClient via lifespan, _use_client() helper in streaming
- Router: shared AsyncClient, parallel health checks via asyncio.gather
- Data proxy: Pydantic response models, shared client, parallel callbacks
- InfBridge: shared AsyncClient with proper aclose() lifecycle
- Parallelize: proxy registration, set_version, pause/continue broadcasts
- Add Pydantic BaseModel response types across all services for type safety

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae8c792 commit 4215028

8 files changed

Lines changed: 686 additions & 293 deletions

File tree

areal/experimental/inference_service/controller/controller.py

Lines changed: 112 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from threading import Lock
2424
from typing import TYPE_CHECKING, Any, cast
2525

26+
import httpx
2627
from openai.types.chat import ChatCompletion, ChatCompletionChunk
2728

2829
if TYPE_CHECKING:
@@ -168,6 +169,11 @@ def __init__(
168169
# Each entry: (guard_addr, role, worker_index) for /kill_forked_worker.
169170
self._forked_services: list[tuple[str, str, int]] = []
170171

172+
# Shared HTTP clients
173+
self._sync_client = httpx.Client(timeout=30.0)
174+
self._async_client = httpx.AsyncClient(timeout=config.request_timeout)
175+
self._destroyed = False
176+
171177
# Proxy compatibility (no-ops — gateway IS the proxy)
172178
self._proxy_started = False
173179
self.proxy_workers: list = []
@@ -259,8 +265,6 @@ async def _async_initialize(
259265
"""
260266
from dataclasses import asdict
261267

262-
import requests
263-
264268
from areal.api.cli_args import SchedulingSpec, SchedulingStrategy
265269
from areal.api.scheduler_api import Job
266270

@@ -431,10 +435,9 @@ def _build_launch_cmd(
431435
# Allocate rendezvous port on head node for distributed init
432436
dist_init_addr = None
433437
if nnodes_per_instance > 1:
434-
resp = requests.post(
438+
resp = self._sync_client.post(
435439
f"{head_guard_addr}/alloc_ports",
436440
json={"count": 1},
437-
timeout=30,
438441
)
439442
resp.raise_for_status()
440443
rendezvous_data = resp.json()
@@ -449,10 +452,9 @@ def _build_launch_cmd(
449452
guard_addr = f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}"
450453

451454
# Allocate port for inference server on this node
452-
resp = requests.post(
455+
resp = self._sync_client.post(
453456
f"{guard_addr}/alloc_ports",
454457
json={"count": 1},
455-
timeout=30,
456458
)
457459
resp.raise_for_status()
458460
port_data = resp.json()
@@ -494,10 +496,9 @@ def _build_launch_cmd(
494496
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
495497
}
496498

497-
resp = requests.post(
499+
resp = self._sync_client.post(
498500
f"{guard_addr}/fork",
499501
json=fork_payload,
500-
timeout=30,
501502
)
502503
resp.raise_for_status()
503504
self._forked_services.append(
@@ -659,41 +660,54 @@ def _wait_for_service(
659660
self, url: str, name: str, timeout: float | None = None
660661
) -> None:
661662
"""Wait for a service to become healthy."""
662-
import requests
663-
664663
timeout = timeout or self.config.setup_timeout
665664
deadline = time.monotonic() + timeout
666665
while time.monotonic() < deadline:
667666
try:
668-
resp = requests.get(url, timeout=2)
667+
resp = self._sync_client.get(url, timeout=2)
669668
if resp.status_code == 200:
670669
logger.info("%s is ready at %s", name, url)
671670
return
672-
except requests.RequestException:
671+
except httpx.HTTPError:
673672
pass
674673
time.sleep(0.1)
675674
raise TimeoutError(f"{name} did not become healthy at {url} within {timeout}s")
676675

677676
def _register_data_proxies_in_router(self) -> None:
678677
"""Register all data proxy workers in the router and store their worker IDs."""
679-
import requests
678+
if not self._data_proxy_addrs:
679+
return
680680

681-
for data_proxy_addr in self._data_proxy_addrs:
682-
resp = requests.post(
683-
f"{self._router_addr}/register",
684-
json={"worker_addr": data_proxy_addr},
685-
headers={"Authorization": f"Bearer {self.config.admin_api_key}"},
686-
timeout=5,
687-
)
681+
from concurrent.futures import ThreadPoolExecutor
682+
683+
admin_key = self.config.admin_api_key
684+
router_addr = self._router_addr
685+
686+
def _register_one(data_proxy_addr: str) -> tuple[str, str | None]:
687+
# Each thread gets its own httpx.Client because httpx.Client
688+
# is not thread-safe and must not be shared across threads.
689+
with httpx.Client() as client:
690+
resp = client.post(
691+
f"{router_addr}/register",
692+
json={"worker_addr": data_proxy_addr},
693+
headers={"Authorization": f"Bearer {admin_key}"},
694+
timeout=5,
695+
)
688696
resp.raise_for_status()
689697
worker_id = resp.json().get("worker_id")
690-
if worker_id:
691-
self._worker_ids[data_proxy_addr] = worker_id
692698
logger.info(
693699
"Registered data proxy %s in router (worker_id=%s)",
694700
data_proxy_addr,
695701
worker_id,
696702
)
703+
return data_proxy_addr, worker_id
704+
705+
with ThreadPoolExecutor(max_workers=len(self._data_proxy_addrs)) as pool:
706+
results = list(pool.map(_register_one, self._data_proxy_addrs))
707+
708+
for data_proxy_addr, worker_id in results:
709+
if worker_id:
710+
self._worker_ids[data_proxy_addr] = worker_id
697711

698712
def register_model(
699713
self,
@@ -702,11 +716,9 @@ def register_model(
702716
api_key: str | None = None,
703717
data_proxy_addrs: list[str] | None = None,
704718
) -> None:
705-
import requests
706-
707719
if data_proxy_addrs is None:
708720
data_proxy_addrs = self._data_proxy_addrs
709-
resp = requests.post(
721+
resp = self._sync_client.post(
710722
f"{self._gateway_addr}/register_model",
711723
json={
712724
"model": model,
@@ -872,6 +884,9 @@ async def _handle_online_ready_callback(
872884

873885
def destroy(self) -> None:
874886
"""Tear down all services and release resources."""
887+
if self._destroyed:
888+
return
889+
self._destroyed = True
875890
self._stop_online_callback_server()
876891

877892
# Destroy workflow executor
@@ -893,6 +908,16 @@ def destroy(self) -> None:
893908
)
894909
self._forked_services.clear()
895910

911+
# Close shared HTTP clients after all kill requests have been sent
912+
self._sync_client.close()
913+
try:
914+
from areal.infra.utils.concurrent import run_async_task
915+
916+
run_async_task(self._async_client.aclose)
917+
except Exception:
918+
# Best-effort cleanup on the destroy path.
919+
pass
920+
896921
# RPCGuard's shutdown `finally` block automatically kills all
897922
# forked children, so explicit teardown above is best-effort.
898923
# Delete all RPCGuard workers via scheduler
@@ -939,8 +964,20 @@ def set_version(self, version: int) -> None:
939964

940965
async def _async_set_version(self, version: int) -> None:
941966
payload = {"version": version}
942-
for wid in self._worker_ids.values():
943-
await self._async_gateway_http_post(f"/set_version/{wid}", payload)
967+
results = await asyncio.gather(
968+
*[
969+
self._async_gateway_http_post(f"/set_version/{wid}", payload)
970+
for wid in self._worker_ids.values()
971+
],
972+
return_exceptions=True,
973+
)
974+
failed = [r for r in results if isinstance(r, Exception)]
975+
for r in failed:
976+
logger.error("Failed to set version on a worker: %s", r)
977+
if failed and len(failed) == len(results):
978+
raise RuntimeError(
979+
f"set_version({version}) failed on ALL {len(failed)} workers"
980+
)
944981

945982
def get_version(self) -> int:
946983
"""Return the local version (compatible with VersionProvider protocol)."""
@@ -1196,8 +1233,9 @@ async def chat_completion(
11961233
return self._stream_chat_completion(url, body, headers)
11971234

11981235
# Non-streaming path
1199-
timeout = aiohttp.ClientTimeout(total=self.config.request_timeout)
1200-
async with aiohttp.ClientSession(timeout=timeout) as session:
1236+
async with aiohttp.ClientSession(
1237+
timeout=aiohttp.ClientTimeout(total=self.config.request_timeout)
1238+
) as session:
12011239
async with session.post(url, json=body, headers=headers) as resp:
12021240
if resp.status != 200:
12031241
text = await resp.text()
@@ -1217,8 +1255,9 @@ async def _stream_chat_completion(
12171255
"""Parse SSE stream from the gateway into ChatCompletionChunk objects."""
12181256
import aiohttp
12191257

1220-
timeout = aiohttp.ClientTimeout(total=self.config.request_timeout)
1221-
session = aiohttp.ClientSession(timeout=timeout)
1258+
session = aiohttp.ClientSession(
1259+
timeout=aiohttp.ClientTimeout(total=self.config.request_timeout)
1260+
)
12221261
try:
12231262
resp = await session.post(url, json=body, headers=headers)
12241263
if resp.status != 200:
@@ -1270,8 +1309,20 @@ async def pause_generation(self, worker_id: str | None = None) -> None:
12701309
if worker_id is not None:
12711310
await self._async_gateway_http_post(f"/pause_generation/{worker_id}", {})
12721311
else:
1273-
for wid in self._worker_ids.values():
1274-
await self._async_gateway_http_post(f"/pause_generation/{wid}", {})
1312+
results = await asyncio.gather(
1313+
*[
1314+
self._async_gateway_http_post(f"/pause_generation/{wid}", {})
1315+
for wid in self._worker_ids.values()
1316+
],
1317+
return_exceptions=True,
1318+
)
1319+
failed = [r for r in results if isinstance(r, Exception)]
1320+
for r in failed:
1321+
logger.error("Failed to pause generation on a worker: %s", r)
1322+
if failed and len(failed) == len(results):
1323+
raise RuntimeError(
1324+
f"pause_generation failed on ALL {len(failed)} workers"
1325+
)
12751326

12761327
async def continue_generation(self, worker_id: str | None = None) -> None:
12771328
"""Continue generation on a specific worker, or all workers if worker_id is None."""
@@ -1280,8 +1331,20 @@ async def continue_generation(self, worker_id: str | None = None) -> None:
12801331
if worker_id is not None:
12811332
await self._async_gateway_http_post(f"/continue_generation/{worker_id}", {})
12821333
else:
1283-
for wid in self._worker_ids.values():
1284-
await self._async_gateway_http_post(f"/continue_generation/{wid}", {})
1334+
results = await asyncio.gather(
1335+
*[
1336+
self._async_gateway_http_post(f"/continue_generation/{wid}", {})
1337+
for wid in self._worker_ids.values()
1338+
],
1339+
return_exceptions=True,
1340+
)
1341+
failed = [r for r in results if isinstance(r, Exception)]
1342+
for r in failed:
1343+
logger.error("Failed to continue generation on a worker: %s", r)
1344+
if failed and len(failed) == len(results):
1345+
raise RuntimeError(
1346+
f"continue_generation failed on ALL {len(failed)} workers"
1347+
)
12851348

12861349
# -- Stats -------------------------------------------------------------
12871350

@@ -1505,12 +1568,9 @@ def _fork_on_guard(
15051568
Returns ``(host, port)`` of the forked service and records the entry
15061569
in ``_forked_services`` for cleanup.
15071570
"""
1508-
import requests
1509-
1510-
resp = requests.post(
1571+
resp = self._sync_client.post(
15111572
f"{guard_addr}/alloc_ports",
15121573
json={"count": 1},
1513-
timeout=30,
15141574
)
15151575
resp.raise_for_status()
15161576
port_data = resp.json()
@@ -1519,14 +1579,13 @@ def _fork_on_guard(
15191579

15201580
cmd = list(raw_cmd) + ["--host", host, "--port", str(port)]
15211581

1522-
resp = requests.post(
1582+
resp = self._sync_client.post(
15231583
f"{guard_addr}/fork",
15241584
json={
15251585
"role": role,
15261586
"worker_index": worker_index,
15271587
"raw_cmd": cmd,
15281588
},
1529-
timeout=30,
15301589
)
15311590
resp.raise_for_status()
15321591

@@ -1540,10 +1599,8 @@ def _fork_on_guard(
15401599
def _kill_forked_service(
15411600
self, guard_addr: str, role: str, worker_index: int
15421601
) -> None:
1543-
import requests
1544-
15451602
try:
1546-
resp = requests.post(
1603+
resp = self._sync_client.post(
15471604
f"{guard_addr}/kill_forked_worker",
15481605
json={"role": role, "worker_index": worker_index},
15491606
timeout=10,
@@ -1557,7 +1614,7 @@ def _kill_forked_service(
15571614
worker_index,
15581615
resp.text,
15591616
)
1560-
except requests.RequestException as exc:
1617+
except httpx.HTTPError as exc:
15611618
logger.error(
15621619
"Error killing forked service %s/%d: %s", role, worker_index, exc
15631620
)
@@ -1571,11 +1628,9 @@ def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None:
15711628
Raises ``RuntimeError`` on HTTP errors or connection failures so that
15721629
callers (e.g. ``pause()`` / ``resume()``) can detect and handle them.
15731630
"""
1574-
import requests
1575-
15761631
url = f"{self._gateway_addr}{endpoint}"
15771632
try:
1578-
resp = requests.post(
1633+
resp = self._sync_client.post(
15791634
url,
15801635
json=payload,
15811636
headers={"Authorization": f"Bearer {self.config.admin_api_key}"},
@@ -1585,7 +1640,7 @@ def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None:
15851640
raise RuntimeError(
15861641
f"Gateway {endpoint} returned {resp.status_code}: {resp.text}"
15871642
)
1588-
except requests.RequestException as exc:
1643+
except httpx.HTTPError as exc:
15891644
raise RuntimeError(f"Failed to POST {endpoint}: {exc}") from exc
15901645

15911646
async def _async_gateway_http_post(
@@ -1597,19 +1652,16 @@ async def _async_gateway_http_post(
15971652
callers (e.g. ``pause_generation()`` / ``continue_generation()``) can
15981653
detect and handle them.
15991654
"""
1600-
import httpx
1601-
16021655
url = f"{self._gateway_addr}{endpoint}"
16031656
try:
1604-
async with httpx.AsyncClient(timeout=self.config.request_timeout) as client:
1605-
resp = await client.post(
1606-
url,
1607-
json=payload,
1608-
headers={"Authorization": f"Bearer {self.config.admin_api_key}"},
1657+
resp = await self._async_client.post(
1658+
url,
1659+
json=payload,
1660+
headers={"Authorization": f"Bearer {self.config.admin_api_key}"},
1661+
)
1662+
if resp.status_code >= 400:
1663+
raise RuntimeError(
1664+
f"Gateway {endpoint} returned {resp.status_code}: {resp.text}"
16091665
)
1610-
if resp.status_code >= 400:
1611-
raise RuntimeError(
1612-
f"Gateway {endpoint} returned {resp.status_code}: {resp.text}"
1613-
)
16141666
except httpx.HTTPError as exc:
16151667
raise RuntimeError(f"Failed to POST {endpoint}: {exc}") from exc

0 commit comments

Comments (0)