Skip to content

Commit f1fc2d3

Browse files
authored
feat(proxy): add proxy gateway and online RL training mode (#947)
* feat(proxy): add proxy gateway and online RL training mode

  Add a lightweight proxy gateway that routes external user requests to session-pinned backend proxy workers, enabling online RL training where real users drive the rollout loop instead of scripted agents.

  Key changes:
  - New proxy gateway server with session routing, API-key-based auth, readiness queue for online-mode worker coordination, and session refresh with API key reuse
  - New _OnlineAgent that registers workers as ready and blocks until an external user completes a full session lifecycle
  - OpenAIProxyWorkflow gains mode="online" alongside inline/subproc
  - RolloutController.start_proxy_gateway() for gateway lifecycle
  - RLTrainer._init_proxy_for_online() with _EmptyDataLoader for online mode, where data comes from external users
  - Fix export_trajectories race condition: session_id is now required in the request body (resolved via admin auth) so export always targets the correct session even after API key remapping
  - Worker start_session accepts caller-provided api_key for refresh/reuse with conflict detection (409 for active, cleanup for finished)
  - Example scripts and documentation for the online RL workflow
  - 36 new tests covering gateway, worker, and online agent

* fix
1 parent 8c64dcd commit f1fc2d3

27 files changed

Lines changed: 3506 additions & 89 deletions

areal/api/cli_args.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,17 +1378,18 @@ class OpenAIProxyConfig:
13781378
default="inline",
13791379
metadata={
13801380
"help": (
1381-
"OpenAI proxy mode: 'inline' (in-process) or 'subproc' (subprocess). "
1381+
"OpenAI proxy mode: 'inline' (in-process), 'subproc' (subprocess), "
1382+
"or 'online' (external user sessions for online RL training). "
13821383
"`inline` mode runs the provided agent workflow directly in the same process. "
1383-
"It can use the provided `base_url` and `http_client` to reduce overhead. "
1384-
"`subproc` mode launches a separate process to run the agent with `OPENAI_BASE_URL` environment variable, "
1385-
"which offers more flexible deployment options at the cost of larger overhead."
1384+
"`subproc` mode launches a separate process to run the agent. "
1385+
"`online` mode waits for external users to complete sessions via "
1386+
"the proxy gateway URL, enabling online RL training."
13861387
),
1387-
"choices": ["inline", "subproc"],
1388+
"choices": ["inline", "subproc", "online"],
13881389
},
13891390
)
13901391
tool_call_parser: str = field(
1391-
default="qwen3",
1392+
default="qwen",
13921393
metadata={"help": "Parser for tool calls in model output."},
13931394
)
13941395
reasoning_parser: str = field(

areal/engine/sglang_remote.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ def get_version(self) -> int:
271271
"""Get the current weight version."""
272272
return self._engine.get_version()
273273

274+
def set_proxy_gateway_addr(self, addr: str) -> None:
    """Forward the proxy gateway address to the wrapped engine.

    Parameters
    ----------
    addr : str
        Proxy gateway address, passed through unchanged to
        ``self._engine.set_proxy_gateway_addr``.
    """
    # The method is declared ``-> None``, so delegate without
    # propagating whatever the wrapped call returns.
    self._engine.set_proxy_gateway_addr(addr)
276+
274277
async def agenerate(self, req: ModelRequest) -> ModelResponse:
    """Await the wrapped engine's ``agenerate`` for ``req`` and return its response."""
    response = await self._engine.agenerate(req)
    return response

areal/engine/vllm_remote.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ def get_version(self) -> int:
315315
"""Get the current weight version."""
316316
return self._engine.get_version()
317317

318+
def set_proxy_gateway_addr(self, addr: str) -> None:
    """Hand ``addr`` to the wrapped engine's ``set_proxy_gateway_addr``."""
    engine = self._engine
    engine.set_proxy_gateway_addr(addr)
320+
318321
async def agenerate(self, req: ModelRequest) -> ModelResponse:
    """Delegate generation for ``req`` to the wrapped engine and return its awaited result."""
    engine = self._engine
    return await engine.agenerate(req)

areal/experimental/openai/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def __init__(
10421042
self,
10431043
engine: TRolloutEngine,
10441044
tokenizer: "PreTrainedTokenizerFast",
1045-
tool_call_parser: str = "qwen3",
1045+
tool_call_parser: str = "qwen",
10461046
reasoning_parser: str = "qwen3",
10471047
engine_max_tokens: int | None = None,
10481048
chat_template_type: str = "hf",

areal/experimental/openai/proxy/client_session.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,14 @@ async def export_interactions(
139139
"""Export interactions for this session via HTTP.
140140
141141
This method should be called after the session context exits
142-
(i.e., after `__aexit__` has ended the RL session), since
143-
`/export_trajectories` waits for the session to finish.
142+
(i.e., after ``__aexit__`` has ended the RL session), since
143+
``/export_trajectories`` waits for the session to finish.
144+
145+
The request always includes the explicit ``session_id`` so that
146+
the server resolves the correct session regardless of any
147+
API-key-to-session remapping that may have occurred during a
148+
refresh cycle. Admin auth is used because the session key is
149+
not guaranteed to still map to this session.
144150
145151
Parameters
146152
----------
@@ -153,17 +159,22 @@ async def export_interactions(
153159
-------
154160
dict[str, InteractionWithTokenLogpReward]
155161
Dictionary mapping interaction IDs to their data
156-
"""
157162
158-
if self._session_api_key is None:
159-
raise ValueError("Session API key is not set")
163+
Raises
164+
------
165+
ValueError
166+
If ``session_id`` has not been set on this client.
167+
"""
168+
if self.session_id is None:
169+
raise ValueError("session_id must be set before exporting interactions")
160170

161171
url = f"{self.base_url}{EXPORT_TRAJECTORIES_PATHNAME}"
162172
payload = {
173+
"session_id": self.session_id,
163174
"discount": discount,
164175
"style": style,
165176
}
166-
headers = self._session_auth_headers()
177+
headers = self._admin_auth_headers()
167178
async with self._session.post(url, json=payload, headers=headers) as resp:
168179
resp.raise_for_status()
169180
data = await resp.json()
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Internal agent for online training mode.
2+
3+
``_OnlineAgent`` is NOT intended for direct use. It is automatically
4+
created by ``OpenAIProxyWorkflow`` when ``mode="online"``. It registers
5+
the assigned proxy worker as "ready" on the proxy gateway, then blocks
6+
until an external user completes a full session lifecycle.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import Any
12+
13+
import aiohttp
14+
15+
from areal.infra import workflow_context
16+
17+
from .proxy_gateway import CompletedSessionInfo
18+
from .server import INTERNAL_WAIT_FOR_SESSION_PATHNAME
19+
20+
21+
class _OnlineAgent:
    """Placeholder agent that hands the rollout over to an external user.

    Rather than driving a conversation itself, this agent reports its
    assigned proxy worker as "ready" to the proxy gateway and then blocks
    until an external user has gone through a complete session lifecycle
    (start_session → interact → set_reward → end_session) on that worker.

    Parameters
    ----------
    proxy_gateway_addr : str
        HTTP address of the proxy gateway server.
    admin_api_key : str
        Admin API key sent as a bearer token to the proxy gateway.
    timeout : float
        Upper bound, in seconds, on the wait for a completed session.
    """

    def __init__(
        self,
        proxy_gateway_addr: str,
        admin_api_key: str,
        timeout: float = 3600.0,
    ):
        self.timeout = timeout
        self.admin_api_key = admin_api_key
        self.proxy_gateway_addr = proxy_gateway_addr

    async def run(
        self, data: dict[str, Any], **extra_kwargs: Any
    ) -> CompletedSessionInfo:
        """Block until an external user finishes a session on this worker.

        Parameters
        ----------
        data : dict
            Not used in online mode (the dataloader yields empty dicts).
        extra_kwargs : dict
            Supplied by ``OpenAIProxyWorkflow``:

            - ``base_url``: proxy worker address (required)
            - ``api_key``: admin API key
            - ``http_client``: ``httpx.AsyncClient`` (not used here)

        Returns
        -------
        CompletedSessionInfo
            Session credentials for trajectory export, built from the
            gateway's JSON reply.

        Raises
        ------
        KeyError
            If ``base_url`` is missing from ``extra_kwargs``.
        """
        # Looked up first so a missing worker address fails before any I/O.
        worker_addr = extra_kwargs["base_url"]

        endpoint = f"{self.proxy_gateway_addr}/{INTERNAL_WAIT_FOR_SESSION_PATHNAME}"
        request_kwargs = {
            "json": {"worker_addr": worker_addr},
            "headers": {"Authorization": f"Bearer {self.admin_api_key}"},
            # The per-request timeout bounds the whole wait for a session.
            "timeout": aiohttp.ClientTimeout(total=self.timeout),
        }

        http = await workflow_context.get_aiohttp_session()
        async with http.post(endpoint, **request_kwargs) as reply:
            reply.raise_for_status()
            session_fields = await reply.json()
            return CompletedSessionInfo(**session_fields)

0 commit comments

Comments
 (0)