Skip to content
2 changes: 1 addition & 1 deletion agent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ USER nobody
EXPOSE 8090
EXPOSE 50061

CMD ["bash", "-lc", "uvicorn ai.web:create_app --factory --host 0.0.0.0 --port ${APP_PORT:-8090}"]
CMD ["bash", "-lc", "uvicorn ai.web:create_app --factory --host 0.0.0.0 --port ${APP_PORT:-8090} --timeout-graceful-shutdown ${UVICORN_GRACEFUL_SHUTDOWN_TIMEOUT:-310}"]
128 changes: 113 additions & 15 deletions agent/src/ai/repl_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,95 @@
from ai.text import normalize_optional


_DEFAULT_DRAIN_TIMEOUT = 300.0
_DRAIN_LOG_INTERVAL = 5.0


def _resolve_drain_timeout() -> float:
raw = os.getenv("DRAIN_TIMEOUT", "").strip()
if not raw:
return _DEFAULT_DRAIN_TIMEOUT
try:
return max(float(raw), 0.0)
except ValueError:
return _DEFAULT_DRAIN_TIMEOUT


class ActiveStreamTracker:
"""Tracks in-flight SSE streams so the service can drain them before shutting down."""

def __init__(self, drain_timeout: float | None = None) -> None:
self._active_count = 0
self._lock = asyncio.Lock()
self._drained = asyncio.Event()
self._drained.set()
self._shutting_down = False
self._drain_timeout = (
drain_timeout if drain_timeout is not None else _resolve_drain_timeout()
)

@property
def is_shutting_down(self) -> bool:
return self._shutting_down

@property
def active_count(self) -> int:
return self._active_count

@asynccontextmanager
async def track_stream(self) -> AsyncIterator[None]:
async with self._lock:
self._active_count += 1
self._drained.clear()
try:
yield
finally:
async with self._lock:
self._active_count -= 1
if self._active_count == 0:
self._drained.set()

def begin_shutdown(self) -> None:
self._shutting_down = True

async def wait_for_drain(self) -> None:
if self._active_count == 0:
print("[web] graceful shutdown: no active streams, proceeding with cleanup", flush=True)
return

print(
f"[web] graceful shutdown: waiting for {self._active_count} active stream(s) to finish"
f" (timeout={self._drain_timeout}s)...",
flush=True,
)
try:
await asyncio.wait_for(self._drain_with_logging(), timeout=self._drain_timeout)
print(
"[web] graceful shutdown: all streams finished, proceeding with cleanup",
flush=True,
)
except TimeoutError:
remaining = self._active_count
print(
f"[web] graceful shutdown: drain timeout reached"
f" with {remaining} stream(s) still active,"
" forcing shutdown",
flush=True,
)

async def _drain_with_logging(self) -> None:
while not self._drained.is_set():
try:
await asyncio.wait_for(self._drained.wait(), timeout=_DRAIN_LOG_INTERVAL)
except TimeoutError:
remaining = self._active_count
print(
f"[web] graceful shutdown: waiting for"
f" {remaining} active stream(s) to finish...",
flush=True,
)


@dataclass(frozen=True)
class WebServerConfig:
host: str = "127.0.0.1"
Expand Down Expand Up @@ -413,13 +502,17 @@ def _create_app() -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manage app-scoped resources: session store, stream tracker, and internal gRPC server.

    On shutdown, in-flight SSE streams are drained (up to the tracker's
    timeout) before the gRPC server and session store are torn down.
    """
    store = SessionStore()
    tracker = ActiveStreamTracker()
    # Expose both on app.state so request handlers can reach them.
    app.state.session_store = store
    app.state.stream_tracker = tracker
    grpc_server = InternalAgentServer.from_env(store)
    grpc_server.start()
    app.state.internal_agent_server = grpc_server
    try:
        yield
    finally:
        # Order matters: flag shutdown first (the stream endpoint rejects new
        # streams once is_shutting_down is set), wait for active streams to
        # finish, and only then stop the backends those streams depend on.
        tracker.begin_shutdown()
        await tracker.wait_for_drain()
        grpc_server.stop()
        store.close()

Expand All @@ -438,6 +531,10 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:

@app.post("/agents/chats/{chat_id}/stream")
async def stream_repl(chat_id: str, payload: ReplStreamRequest, request: Request) -> StreamingResponse:
tracker: ActiveStreamTracker = request.app.state.stream_tracker
if tracker.is_shutting_down:
raise HTTPException(status_code=503, detail="Service is shutting down")

if payload.model != "test" and _resolve_bearer_token(request) is None:
raise HTTPException(status_code=401, detail="Authorization header is required")

Expand All @@ -450,21 +547,22 @@ async def stream_repl(chat_id: str, payload: ReplStreamRequest, request: Request
)

async def event_generator() -> AsyncIterator[str]:
try:
async for event in _stream_agent_run(chat_id, payload, request):
if await request.is_disconnected():
_debug_log("client disconnected", chat_id=chat_id)
break
yield _encode_sse_event(event)
except Exception as error:
_debug_log("stream failed", chat_id=chat_id, error=str(error))
yield _encode_sse_event(
{
"type": "run_failed",
"error": str(error),
}
)
yield _encode_sse_event({"type": "done"})
async with tracker.track_stream():
try:
async for event in _stream_agent_run(chat_id, payload, request):
if await request.is_disconnected():
_debug_log("client disconnected", chat_id=chat_id)
break
yield _encode_sse_event(event)
except Exception as error:
_debug_log("stream failed", chat_id=chat_id, error=str(error))
yield _encode_sse_event(
{
"type": "run_failed",
"error": str(error),
}
)
yield _encode_sse_event({"type": "done"})

return StreamingResponse(
event_generator(),
Expand Down
78 changes: 78 additions & 0 deletions agent/tests/test_repl_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,84 @@ def _stub_agent_persistence(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(repl_web.InternalAgentServer, "from_env", MagicMock(return_value=fake_grpc_server))


class TestActiveStreamTracker:
    """Unit tests for ActiveStreamTracker counting, shutdown flag, and drain behavior."""

    def test_track_stream_increments_and_decrements(self) -> None:
        async def scenario() -> None:
            tracker = repl_web.ActiveStreamTracker(drain_timeout=5.0)
            assert tracker.active_count == 0

            async with tracker.track_stream():
                assert tracker.active_count == 1
                async with tracker.track_stream():
                    assert tracker.active_count == 2
                assert tracker.active_count == 1
            assert tracker.active_count == 0

        asyncio.run(scenario())

    def test_begin_shutdown_sets_flag(self) -> None:
        tracker = repl_web.ActiveStreamTracker(drain_timeout=5.0)
        assert tracker.is_shutting_down is False
        tracker.begin_shutdown()
        assert tracker.is_shutting_down is True

    def test_wait_for_drain_returns_immediately_when_no_streams(self) -> None:
        async def scenario() -> None:
            tracker = repl_web.ActiveStreamTracker(drain_timeout=5.0)
            await tracker.wait_for_drain()

        asyncio.run(scenario())

    def test_wait_for_drain_waits_for_active_streams(self) -> None:
        async def scenario() -> None:
            tracker = repl_web.ActiveStreamTracker(drain_timeout=5.0)
            completed: list[str] = []

            async def worker(delay: float, label: str) -> None:
                async with tracker.track_stream():
                    await asyncio.sleep(delay)
                    completed.append(label)

            tasks = [
                asyncio.create_task(worker(0.1, "fast")),
                asyncio.create_task(worker(0.3, "slow")),
            ]

            await asyncio.sleep(0.01)
            assert tracker.active_count == 2

            tracker.begin_shutdown()
            await tracker.wait_for_drain()

            assert tracker.active_count == 0
            assert "fast" in completed
            assert "slow" in completed
            del tasks  # tasks have completed by the time drain returns

        asyncio.run(scenario())

    def test_wait_for_drain_respects_timeout(self) -> None:
        async def scenario() -> None:
            tracker = repl_web.ActiveStreamTracker(drain_timeout=0.2)

            async def never_finishes() -> None:
                async with tracker.track_stream():
                    await asyncio.sleep(10.0)

            stuck = asyncio.create_task(never_finishes())
            await asyncio.sleep(0.01)
            assert tracker.active_count == 1

            tracker.begin_shutdown()
            await tracker.wait_for_drain()

            # Drain gave up while the stream was still running.
            assert tracker.active_count == 1
            stuck.cancel()
            try:
                await stuck
            except asyncio.CancelledError:
                pass

        asyncio.run(scenario())


def test_parse_stream_event_accepts_valid_sse_line() -> None:
    """A well-formed `data:` SSE line decodes to its JSON payload dict."""
    event = _parse_stream_event(b'data: {"type":"model_delta","content":"hello"}\n')
    assert event == {"type": "model_delta", "content": "hello"}
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ services:
--factory
--host 0.0.0.0
--port 8090
--timeout-graceful-shutdown 310
--reload
--reload-dir /app/agent/src
--reload-dir /app/agent/evals"
stop_grace_period: 330s
tty: true
stdin_open: true
ports:
Expand Down
Loading