Skip to content

Commit 03f3f3c

Browse files
committed
fix
1 parent 9004f25 commit 03f3f3c

4 files changed

Lines changed: 44 additions & 21 deletions

File tree

areal/experimental/openai/proxy/proxy_gateway.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -365,20 +365,22 @@ async def start_session(request: Request):
365365
"reuse" if requested_key else "new",
366366
)
367367

368+
# Reject if a refresh is already in flight for this key.
369+
# Must be checked BEFORE `in routes` since refresh pops the route.
370+
if requested_key and requested_key in _refreshing:
371+
return Response(
372+
status_code=429,
373+
content=json.dumps(
374+
{"detail": "A refresh is already in progress for this key."}
375+
).encode(),
376+
)
377+
368378
# ---- REFRESH PATH ----
369379
# Known key with an active route → end old session, wait for
370380
# the training pipeline to cycle, start a new session.
371381
if requested_key and requested_key in known_keys and requested_key in routes:
372-
# Reject concurrent refresh for the same key.
373-
if requested_key in _refreshing:
374-
return Response(
375-
status_code=429,
376-
content=json.dumps(
377-
{"detail": "A refresh is already in progress for this key."}
378-
).encode(),
379-
)
380-
381382
_refreshing.add(requested_key)
383+
ready_entry: _ReadyWorkerEntry | None = None
382384
try:
383385
old_route = routes.pop(requested_key)
384386
logger.info(
@@ -391,7 +393,7 @@ async def start_session(request: Request):
391393

392394
# Skip stale ready entries within the deadline.
393395
deadline = asyncio.get_running_loop().time() + refresh_timeout
394-
ready_entry: _ReadyWorkerEntry | None = None
396+
ready_entry = None
395397
while True:
396398
remaining = deadline - asyncio.get_running_loop().time()
397399
if remaining <= 0:
@@ -465,6 +467,9 @@ async def start_session(request: Request):
465467
except Exception:
466468
known_keys.pop(requested_key, None)
467469
_reject_future(old_route.pending_future, "Refresh failed unexpectedly")
470+
# Also settle the new worker's future if we consumed one
471+
if ready_entry is not None:
472+
_reject_future(ready_entry.future, "Refresh failed unexpectedly")
468473
raise
469474
finally:
470475
_refreshing.discard(requested_key)

areal/trainer/rl_trainer.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,17 @@ class _EmptyDataLoader:
6060
"""Minimal dataloader for online mode that yields empty dicts.
6161
6262
Compatible with ``cycle_dataloader()`` and ``len()`` expectations.
63-
Each "epoch" produces a single batch of ``batch_size`` empty dicts,
64-
so the training loop collects the correct number of trajectories
65-
before proceeding to a train step.
63+
``steps_per_epoch`` controls how many steps constitute one epoch,
64+
derived from ``total_train_steps // total_train_epochs`` to ensure
65+
epoch-frequency-gated components (Saver, RecoverHandler) behave correctly.
6666
"""
6767

68-
def __init__(self, batch_size: int = 1):
68+
def __init__(self, batch_size: int = 1, steps_per_epoch: int = 1):
6969
self.batch_size = batch_size
70+
self._steps_per_epoch = steps_per_epoch
7071

7172
def __len__(self) -> int:
72-
return 1 # 1 step per "epoch" for online mode
73+
return self._steps_per_epoch
7374

7475
def __iter__(self):
7576
while True:
@@ -123,9 +124,26 @@ def __init__(
123124
self.train_dataset = train_dataset
124125
self.valid_dataset = valid_dataset
125126
if train_dataset is None:
126-
# Online mode: use empty data generator
127+
# Online mode: require total_train_steps to compute steps_per_epoch.
128+
# Without this, __len__()=1 causes every step to be treated as an
129+
# epoch boundary, making Saver/RecoverHandler fire every step and
130+
# corrupting the LR schedule.
131+
if config.total_train_steps is None:
132+
raise ValueError(
133+
"total_train_steps must be set for online mode "
134+
"(train_dataset is None). Both total_train_epochs and "
135+
"total_train_steps are needed to compute steps_per_epoch."
136+
)
137+
steps_per_epoch = config.total_train_steps // config.total_train_epochs
138+
if steps_per_epoch < 1:
139+
raise ValueError(
140+
f"total_train_steps ({config.total_train_steps}) must be >= "
141+
f"total_train_epochs ({config.total_train_epochs}) so that "
142+
f"steps_per_epoch >= 1."
143+
)
127144
self.train_dataloader = _EmptyDataLoader(
128-
batch_size=config.train_dataset.batch_size
145+
batch_size=config.train_dataset.batch_size,
146+
steps_per_epoch=steps_per_epoch,
129147
)
130148
else:
131149
self.train_dataloader = self._create_dataloader(

examples/online_rl/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ trial_name: trial0
44
seed: 1
55
enable_offload: false
66
total_train_epochs: 10
7+
total_train_steps: 100
78
tokenizer_path: ${actor.path}
89

910
cluster:

tests/experimental/openai/test_proxy_gateway.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,9 @@ async def test_concurrent_refresh_same_key_returns_429(self):
678678
headers=_admin_headers(),
679679
json={"task_id": "t", "api_key": "k1"},
680680
)
681-
# The first refresh already popped the route, so the second
682-
# request falls through to round-robin (no mock → 500) or
683-
# hits the sentinel (429) if the route was still present.
684-
assert resp2.status_code in (429, 500)
681+
# The _refreshing guard now rejects concurrent refreshes for the
682+
# same key before checking `routes`, so this reliably returns 429.
683+
assert resp2.status_code == 429
685684

686685
# Clean up.
687686
refresh1.cancel()

0 commit comments

Comments
 (0)