2323from threading import Lock
2424from typing import TYPE_CHECKING , Any , cast
2525
26+ import httpx
2627from openai .types .chat import ChatCompletion , ChatCompletionChunk
2728
2829if TYPE_CHECKING :
@@ -168,6 +169,11 @@ def __init__(
168169 # Each entry: (guard_addr, role, worker_index) for /kill_forked_worker.
169170 self ._forked_services : list [tuple [str , str , int ]] = []
170171
172+ # Shared HTTP clients
173+ self ._sync_client = httpx .Client (timeout = 30.0 )
174+ self ._async_client = httpx .AsyncClient (timeout = config .request_timeout )
175+ self ._destroyed = False
176+
171177 # Proxy compatibility (no-ops — gateway IS the proxy)
172178 self ._proxy_started = False
173179 self .proxy_workers : list = []
@@ -259,8 +265,6 @@ async def _async_initialize(
259265 """
260266 from dataclasses import asdict
261267
262- import requests
263-
264268 from areal .api .cli_args import SchedulingSpec , SchedulingStrategy
265269 from areal .api .scheduler_api import Job
266270
@@ -431,10 +435,9 @@ def _build_launch_cmd(
431435 # Allocate rendezvous port on head node for distributed init
432436 dist_init_addr = None
433437 if nnodes_per_instance > 1 :
434- resp = requests .post (
438+ resp = self . _sync_client .post (
435439 f"{ head_guard_addr } /alloc_ports" ,
436440 json = {"count" : 1 },
437- timeout = 30 ,
438441 )
439442 resp .raise_for_status ()
440443 rendezvous_data = resp .json ()
@@ -449,10 +452,9 @@ def _build_launch_cmd(
449452 guard_addr = f"http://{ format_hostport (worker .ip , int (worker .worker_ports [0 ]))} "
450453
451454 # Allocate port for inference server on this node
452- resp = requests .post (
455+ resp = self . _sync_client .post (
453456 f"{ guard_addr } /alloc_ports" ,
454457 json = {"count" : 1 },
455- timeout = 30 ,
456458 )
457459 resp .raise_for_status ()
458460 port_data = resp .json ()
@@ -494,10 +496,9 @@ def _build_launch_cmd(
494496 "VLLM_ALLOW_RUNTIME_LORA_UPDATING" : "True" ,
495497 }
496498
497- resp = requests .post (
499+ resp = self . _sync_client .post (
498500 f"{ guard_addr } /fork" ,
499501 json = fork_payload ,
500- timeout = 30 ,
501502 )
502503 resp .raise_for_status ()
503504 self ._forked_services .append (
@@ -659,41 +660,54 @@ def _wait_for_service(
659660 self , url : str , name : str , timeout : float | None = None
660661 ) -> None :
661662 """Wait for a service to become healthy."""
662- import requests
663-
664663 timeout = timeout or self .config .setup_timeout
665664 deadline = time .monotonic () + timeout
666665 while time .monotonic () < deadline :
667666 try :
668- resp = requests .get (url , timeout = 2 )
667+ resp = self . _sync_client .get (url , timeout = 2 )
669668 if resp .status_code == 200 :
670669 logger .info ("%s is ready at %s" , name , url )
671670 return
672- except requests . RequestException :
671+ except httpx . HTTPError :
673672 pass
674673 time .sleep (0.1 )
675674 raise TimeoutError (f"{ name } did not become healthy at { url } within { timeout } s" )
676675
677676 def _register_data_proxies_in_router (self ) -> None :
678677 """Register all data proxy workers in the router and store their worker IDs."""
679- import requests
678+ if not self ._data_proxy_addrs :
679+ return
680680
681- for data_proxy_addr in self ._data_proxy_addrs :
682- resp = requests .post (
683- f"{ self ._router_addr } /register" ,
684- json = {"worker_addr" : data_proxy_addr },
685- headers = {"Authorization" : f"Bearer { self .config .admin_api_key } " },
686- timeout = 5 ,
687- )
681+ from concurrent .futures import ThreadPoolExecutor
682+
683+ admin_key = self .config .admin_api_key
684+ router_addr = self ._router_addr
685+
686+ def _register_one (data_proxy_addr : str ) -> tuple [str , str | None ]:
687+ # Each thread gets its own httpx.Client because httpx.Client
688+ # is not thread-safe and must not be shared across threads.
689+ with httpx .Client () as client :
690+ resp = client .post (
691+ f"{ router_addr } /register" ,
692+ json = {"worker_addr" : data_proxy_addr },
693+ headers = {"Authorization" : f"Bearer { admin_key } " },
694+ timeout = 5 ,
695+ )
688696 resp .raise_for_status ()
689697 worker_id = resp .json ().get ("worker_id" )
690- if worker_id :
691- self ._worker_ids [data_proxy_addr ] = worker_id
692698 logger .info (
693699 "Registered data proxy %s in router (worker_id=%s)" ,
694700 data_proxy_addr ,
695701 worker_id ,
696702 )
703+ return data_proxy_addr , worker_id
704+
705+ with ThreadPoolExecutor (max_workers = len (self ._data_proxy_addrs )) as pool :
706+ results = list (pool .map (_register_one , self ._data_proxy_addrs ))
707+
708+ for data_proxy_addr , worker_id in results :
709+ if worker_id :
710+ self ._worker_ids [data_proxy_addr ] = worker_id
697711
698712 def register_model (
699713 self ,
@@ -702,11 +716,9 @@ def register_model(
702716 api_key : str | None = None ,
703717 data_proxy_addrs : list [str ] | None = None ,
704718 ) -> None :
705- import requests
706-
707719 if data_proxy_addrs is None :
708720 data_proxy_addrs = self ._data_proxy_addrs
709- resp = requests .post (
721+ resp = self . _sync_client .post (
710722 f"{ self ._gateway_addr } /register_model" ,
711723 json = {
712724 "model" : model ,
@@ -872,6 +884,9 @@ async def _handle_online_ready_callback(
872884
873885 def destroy (self ) -> None :
874886 """Tear down all services and release resources."""
887+ if self ._destroyed :
888+ return
889+ self ._destroyed = True
875890 self ._stop_online_callback_server ()
876891
877892 # Destroy workflow executor
@@ -893,6 +908,16 @@ def destroy(self) -> None:
893908 )
894909 self ._forked_services .clear ()
895910
911+ # Close shared HTTP clients after all kill requests have been sent
912+ self ._sync_client .close ()
913+ try :
914+ from areal .infra .utils .concurrent import run_async_task
915+
916+ run_async_task (self ._async_client .aclose )
917+ except Exception :
918+ # Best-effort cleanup on the destroy path.
919+ pass
920+
896921 # RPCGuard's shutdown `finally` block automatically kills all
897922 # forked children, so explicit teardown above is best-effort.
898923 # Delete all RPCGuard workers via scheduler
@@ -939,8 +964,20 @@ def set_version(self, version: int) -> None:
939964
940965 async def _async_set_version (self , version : int ) -> None :
941966 payload = {"version" : version }
942- for wid in self ._worker_ids .values ():
943- await self ._async_gateway_http_post (f"/set_version/{ wid } " , payload )
967+ results = await asyncio .gather (
968+ * [
969+ self ._async_gateway_http_post (f"/set_version/{ wid } " , payload )
970+ for wid in self ._worker_ids .values ()
971+ ],
972+ return_exceptions = True ,
973+ )
974+ failed = [r for r in results if isinstance (r , Exception )]
975+ for r in failed :
976+ logger .error ("Failed to set version on a worker: %s" , r )
977+ if failed and len (failed ) == len (results ):
978+ raise RuntimeError (
979+ f"set_version({ version } ) failed on ALL { len (failed )} workers"
980+ )
944981
945982 def get_version (self ) -> int :
946983 """Return the local version (compatible with VersionProvider protocol)."""
@@ -1196,8 +1233,9 @@ async def chat_completion(
11961233 return self ._stream_chat_completion (url , body , headers )
11971234
11981235 # Non-streaming path
1199- timeout = aiohttp .ClientTimeout (total = self .config .request_timeout )
1200- async with aiohttp .ClientSession (timeout = timeout ) as session :
1236+ async with aiohttp .ClientSession (
1237+ timeout = aiohttp .ClientTimeout (total = self .config .request_timeout )
1238+ ) as session :
12011239 async with session .post (url , json = body , headers = headers ) as resp :
12021240 if resp .status != 200 :
12031241 text = await resp .text ()
@@ -1217,8 +1255,9 @@ async def _stream_chat_completion(
12171255 """Parse SSE stream from the gateway into ChatCompletionChunk objects."""
12181256 import aiohttp
12191257
1220- timeout = aiohttp .ClientTimeout (total = self .config .request_timeout )
1221- session = aiohttp .ClientSession (timeout = timeout )
1258+ session = aiohttp .ClientSession (
1259+ timeout = aiohttp .ClientTimeout (total = self .config .request_timeout )
1260+ )
12221261 try :
12231262 resp = await session .post (url , json = body , headers = headers )
12241263 if resp .status != 200 :
@@ -1270,8 +1309,20 @@ async def pause_generation(self, worker_id: str | None = None) -> None:
12701309 if worker_id is not None :
12711310 await self ._async_gateway_http_post (f"/pause_generation/{ worker_id } " , {})
12721311 else :
1273- for wid in self ._worker_ids .values ():
1274- await self ._async_gateway_http_post (f"/pause_generation/{ wid } " , {})
1312+ results = await asyncio .gather (
1313+ * [
1314+ self ._async_gateway_http_post (f"/pause_generation/{ wid } " , {})
1315+ for wid in self ._worker_ids .values ()
1316+ ],
1317+ return_exceptions = True ,
1318+ )
1319+ failed = [r for r in results if isinstance (r , Exception )]
1320+ for r in failed :
1321+ logger .error ("Failed to pause generation on a worker: %s" , r )
1322+ if failed and len (failed ) == len (results ):
1323+ raise RuntimeError (
1324+ f"pause_generation failed on ALL { len (failed )} workers"
1325+ )
12751326
12761327 async def continue_generation (self , worker_id : str | None = None ) -> None :
12771328 """Continue generation on a specific worker, or all workers if worker_id is None."""
@@ -1280,8 +1331,20 @@ async def continue_generation(self, worker_id: str | None = None) -> None:
12801331 if worker_id is not None :
12811332 await self ._async_gateway_http_post (f"/continue_generation/{ worker_id } " , {})
12821333 else :
1283- for wid in self ._worker_ids .values ():
1284- await self ._async_gateway_http_post (f"/continue_generation/{ wid } " , {})
1334+ results = await asyncio .gather (
1335+ * [
1336+ self ._async_gateway_http_post (f"/continue_generation/{ wid } " , {})
1337+ for wid in self ._worker_ids .values ()
1338+ ],
1339+ return_exceptions = True ,
1340+ )
1341+ failed = [r for r in results if isinstance (r , Exception )]
1342+ for r in failed :
1343+ logger .error ("Failed to continue generation on a worker: %s" , r )
1344+ if failed and len (failed ) == len (results ):
1345+ raise RuntimeError (
1346+ f"continue_generation failed on ALL { len (failed )} workers"
1347+ )
12851348
12861349 # -- Stats -------------------------------------------------------------
12871350
@@ -1505,12 +1568,9 @@ def _fork_on_guard(
15051568 Returns ``(host, port)`` of the forked service and records the entry
15061569 in ``_forked_services`` for cleanup.
15071570 """
1508- import requests
1509-
1510- resp = requests .post (
1571+ resp = self ._sync_client .post (
15111572 f"{ guard_addr } /alloc_ports" ,
15121573 json = {"count" : 1 },
1513- timeout = 30 ,
15141574 )
15151575 resp .raise_for_status ()
15161576 port_data = resp .json ()
@@ -1519,14 +1579,13 @@ def _fork_on_guard(
15191579
15201580 cmd = list (raw_cmd ) + ["--host" , host , "--port" , str (port )]
15211581
1522- resp = requests .post (
1582+ resp = self . _sync_client .post (
15231583 f"{ guard_addr } /fork" ,
15241584 json = {
15251585 "role" : role ,
15261586 "worker_index" : worker_index ,
15271587 "raw_cmd" : cmd ,
15281588 },
1529- timeout = 30 ,
15301589 )
15311590 resp .raise_for_status ()
15321591
@@ -1540,10 +1599,8 @@ def _fork_on_guard(
15401599 def _kill_forked_service (
15411600 self , guard_addr : str , role : str , worker_index : int
15421601 ) -> None :
1543- import requests
1544-
15451602 try :
1546- resp = requests .post (
1603+ resp = self . _sync_client .post (
15471604 f"{ guard_addr } /kill_forked_worker" ,
15481605 json = {"role" : role , "worker_index" : worker_index },
15491606 timeout = 10 ,
@@ -1557,7 +1614,7 @@ def _kill_forked_service(
15571614 worker_index ,
15581615 resp .text ,
15591616 )
1560- except requests . RequestException as exc :
1617+ except httpx . HTTPError as exc :
15611618 logger .error (
15621619 "Error killing forked service %s/%d: %s" , role , worker_index , exc
15631620 )
@@ -1571,11 +1628,9 @@ def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None:
15711628 Raises ``RuntimeError`` on HTTP errors or connection failures so that
15721629 callers (e.g. ``pause()`` / ``resume()``) can detect and handle them.
15731630 """
1574- import requests
1575-
15761631 url = f"{ self ._gateway_addr } { endpoint } "
15771632 try :
1578- resp = requests .post (
1633+ resp = self . _sync_client .post (
15791634 url ,
15801635 json = payload ,
15811636 headers = {"Authorization" : f"Bearer { self .config .admin_api_key } " },
@@ -1585,7 +1640,7 @@ def _gateway_http_post(self, endpoint: str, payload: dict[str, Any]) -> None:
15851640 raise RuntimeError (
15861641 f"Gateway { endpoint } returned { resp .status_code } : { resp .text } "
15871642 )
1588- except requests . RequestException as exc :
1643+ except httpx . HTTPError as exc :
15891644 raise RuntimeError (f"Failed to POST { endpoint } : { exc } " ) from exc
15901645
15911646 async def _async_gateway_http_post (
@@ -1597,19 +1652,16 @@ async def _async_gateway_http_post(
15971652 callers (e.g. ``pause_generation()`` / ``continue_generation()``) can
15981653 detect and handle them.
15991654 """
1600- import httpx
1601-
16021655 url = f"{ self ._gateway_addr } { endpoint } "
16031656 try :
1604- async with httpx .AsyncClient (timeout = self .config .request_timeout ) as client :
1605- resp = await client .post (
1606- url ,
1607- json = payload ,
1608- headers = {"Authorization" : f"Bearer { self .config .admin_api_key } " },
1657+ resp = await self ._async_client .post (
1658+ url ,
1659+ json = payload ,
1660+ headers = {"Authorization" : f"Bearer { self .config .admin_api_key } " },
1661+ )
1662+ if resp .status_code >= 400 :
1663+ raise RuntimeError (
1664+ f"Gateway { endpoint } returned { resp .status_code } : { resp .text } "
16091665 )
1610- if resp .status_code >= 400 :
1611- raise RuntimeError (
1612- f"Gateway { endpoint } returned { resp .status_code } : { resp .text } "
1613- )
16141666 except httpx .HTTPError as exc :
16151667 raise RuntimeError (f"Failed to POST { endpoint } : { exc } " ) from exc
0 commit comments