Skip to content
2 changes: 1 addition & 1 deletion agent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ USER nobody
EXPOSE 8090
EXPOSE 50061

CMD ["bash", "-lc", "uvicorn ai.web:create_app --factory --host 0.0.0.0 --port ${APP_PORT:-8090}"]
CMD ["bash", "-lc", "uvicorn ai.web:create_app --factory --host 0.0.0.0 --port ${APP_PORT:-8090} --timeout-graceful-shutdown ${UVICORN_GRACEFUL_SHUTDOWN_TIMEOUT:-310}"]
4 changes: 2 additions & 2 deletions agent/src/ai/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel

from ai.config import config
from ai.models import CanvasAnswer, CanvasQuestionRequest, CanvasSummary
from ai.patterns import get_decision_pattern as get_markdown_pattern
from ai.patterns import list_decision_patterns as list_markdown_patterns
Expand Down Expand Up @@ -91,7 +91,7 @@ def build_agent(model: str | Literal["test"] = "test") -> Agent[AgentDeps, Canva
)

def _tool_debug(message: str) -> None:
if os.getenv("REPL_WEB_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}:
if config.debug:
print(f"[web][agent] {message}", flush=True)

def _tool_error_entry(tool_name: str, error: Exception) -> dict[str, Any]:
Expand Down
66 changes: 66 additions & 0 deletions agent/src/ai/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os


class Config:
    """Application settings read from environment variables.

    All values are read once, when the instance is constructed. Raw values
    are stripped of surrounding whitespace, and an unset or blank variable
    yields that setting's default. Numeric values that cannot be parsed
    also yield the default; numeric values outside their bounds are clamped
    to the nearest bound.
    """

    def __init__(self) -> None:
        # General behavior.
        self.ai_model: str = self._parse_str("AI_MODEL", default="test")
        self.debug: bool = self._parse_bool("REPL_WEB_DEBUG")
        # Kept as the raw string; no splitting is done here.
        self.cors_origins: str = self._parse_str("REPL_WEB_CORS_ORIGINS", default="*")

        # No default for the base URL: empty string when unset.
        self.superplane_base_url: str = self._parse_str("SUPERPLANE_BASE_URL")
        self.superplane_user_agent: str = self._parse_str("SUPERPLANE_USER_AGENT", default="curl/8.7.1")

        # Clamped to [0, 1000]. NOTE(review): presumably seconds - confirm
        # against the consumer of this value.
        self.drain_timeout: float = self._parse_float(
            "DRAIN_TIMEOUT", lower=0, upper=1000, default=300.0,
        )

        # Database connection settings. Name, username and password have no
        # default and are empty strings when unset.
        self.db_host: str = self._parse_str("DB_HOST", default="db")
        self.db_port: int = self._parse_int("DB_PORT", lower=1, upper=65535, default=5432)
        self.db_name: str = self._parse_str("DB_NAME")
        self.db_username: str = self._parse_str("DB_USERNAME")
        self.db_password: str = self._parse_str("DB_PASSWORD")
        self.db_sslmode: str = self._parse_str("DB_SSLMODE", default="disable")
        self.application_name: str = self._parse_str("APPLICATION_NAME", default="superplane-agent")

        # No default: empty string when unset.
        self.jwt_secret: str = self._parse_str("JWT_SECRET")

        self.grpc_host: str = self._parse_str("INTERNAL_GRPC_HOST", default="0.0.0.0")
        self.grpc_port: int = self._parse_int("INTERNAL_GRPC_PORT", lower=1, upper=65535, default=50061)

        # No default: empty string when unset.
        self.pattern_dir: str = self._parse_str("AGENT_PATTERN_DIR")

    @staticmethod
    def _parse_float(env_name: str, *, lower: float, upper: float, default: float) -> float:
        """Read *env_name* as a float clamped to ``[lower, upper]``.

        Returns *default* unchanged when the variable is unset, blank, not a
        valid float literal, or NaN. Infinite values are clamped like any
        other out-of-range value.
        """
        import math

        raw = os.getenv(env_name, "").strip()
        if not raw:
            return default
        try:
            value = float(raw)
        except ValueError:
            return default
        # float() accepts "nan"; NaN compares false against everything, so
        # the clamp below would silently collapse it to ``lower``. Treat it
        # as an unparsable value instead.
        if math.isnan(value):
            return default
        # Coerce so an integer bound (e.g. lower=0) never leaks out as int.
        return float(max(lower, min(value, upper)))

    @staticmethod
    def _parse_int(env_name: str, *, lower: int, upper: int, default: int) -> int:
        """Read *env_name* as an int clamped to ``[lower, upper]``.

        Returns *default* when the variable is unset, blank, or not a valid
        integer literal.
        """
        raw = os.getenv(env_name, "").strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            return default
        return max(lower, min(value, upper))

    @staticmethod
    def _parse_bool(env_name: str, *, default: bool = False) -> bool:
        """Read *env_name* as a boolean flag.

        Returns *default* when the variable is unset or blank. Otherwise
        returns True only for ``1``, ``true``, ``yes`` or ``on``
        (case-insensitive); any other non-blank value is False.
        """
        raw = os.getenv(env_name, "").strip().lower()
        if not raw:
            return default
        return raw in {"1", "true", "yes", "on"}

    @staticmethod
    def _parse_str(env_name: str, *, default: str = "") -> str:
        """Return the stripped value of *env_name*, or *default* if unset or blank."""
        return os.getenv(env_name, "").strip() or default


config = Config()
6 changes: 2 additions & 4 deletions agent/src/ai/grpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import threading
import uuid

Expand All @@ -8,6 +7,7 @@
import grpc
from google.protobuf.timestamp_pb2 import Timestamp

from ai.config import config
from ai.session_store import AgentChatNotFoundError, SessionStore, StoredAgentChat, StoredAgentChatMessage
from private import agents_pb2

Expand Down Expand Up @@ -133,9 +133,7 @@ def __init__(self, config: AgentServiceConfig, store: SessionStore) -> None:

@classmethod
def from_env(cls, store: SessionStore) -> "InternalAgentServer":
host = os.getenv("INTERNAL_GRPC_HOST", "0.0.0.0").strip() or "0.0.0.0"
port = int(os.getenv("INTERNAL_GRPC_PORT", "50061"))
return cls(AgentServiceConfig(host=host, port=port), store)
return cls(AgentServiceConfig(host=config.grpc_host, port=config.grpc_port), store)

def start(self) -> None:
if self._thread is not None and self._thread.is_alive():
Expand Down
5 changes: 2 additions & 3 deletions agent/src/ai/jwt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import jwt
from dataclasses import dataclass


from ai.config import config
from ai.text import normalize_optional


Expand All @@ -20,7 +19,7 @@ def __init__(self, jwt_secret: str, audience: str = "superplane_api") -> None:

@classmethod
def from_env(cls) -> "JwtValidator":
jwt_secret = normalize_optional(os.getenv("JWT_SECRET"))
jwt_secret = normalize_optional(config.jwt_secret)
if jwt_secret is None:
raise ValueError("Missing required setting: JWT_SECRET")
return cls(jwt_secret=jwt_secret)
Expand Down
11 changes: 4 additions & 7 deletions agent/src/ai/patterns.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import os
import re
from dataclasses import dataclass
from pathlib import Path

from ai.config import config

_TOKEN_RE = re.compile(r"[a-z0-9]+")
_KEYWORDS_PREFIX = "keywords:"

Expand All @@ -19,12 +20,8 @@ class DecisionPattern:


def _resolve_pattern_dir() -> Path:
env_value = os.getenv("AGENT_PATTERN_DIR", "").strip()
if env_value:
env_dir = Path(env_value).expanduser()
return env_dir

# Default to <repo>/agent/patterns
if config.pattern_dir:
return Path(config.pattern_dir).expanduser()
return Path(__file__).resolve().parents[2] / "patterns"


Expand Down
83 changes: 49 additions & 34 deletions agent/src/ai/repl_web.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import json
import os
import threading
import time
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -29,12 +28,14 @@
from pydantic_ai.run import AgentRunResultEvent

from ai.agent import AgentDeps, build_agent
from ai.config import config
from ai.grpc import InternalAgentServer
from ai.models import CanvasAnswer
from ai.jwt import JwtClaims, JwtValidator
from ai.persisted_run_recorder import PersistedRunRecorder
from ai.session_store import AgentChatNotFoundError, SessionStore, StoredAgentChat
from ai.proposal_configuration_coerce import coerce_canvas_answer_proposal
from ai.stream_tracker import ActiveStreamTracker
from ai.superplane_client import SuperplaneClient, SuperplaneClientConfig
from ai.text import normalize_optional

Expand All @@ -48,15 +49,15 @@ class WebServerConfig:
class ReplStreamRequest(BaseModel):
question: str = Field(min_length=1, max_length=2000)
model: str = Field(
default=(os.getenv("AI_MODEL", "test").strip() or "test"),
default=config.ai_model,
min_length=1,
max_length=200,
)
base_url: str | None = None


def _debug_enabled() -> bool:
return os.getenv("REPL_WEB_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
return config.debug


def _debug_log(message: str, **fields: Any) -> None:
Expand All @@ -69,8 +70,8 @@ def _debug_log(message: str, **fields: Any) -> None:
print(f"[web] {message}", flush=True)


def _resolve_required(value: str | None, env_name: str) -> str:
resolved = normalize_optional(value) or normalize_optional(os.getenv(env_name))
def _resolve_required(value: str | None, fallback: str | None, env_name: str) -> str:
resolved = normalize_optional(value) or normalize_optional(fallback)
if resolved is None:
raise ValueError(f"Missing required setting: {env_name}")
return resolved
Expand Down Expand Up @@ -143,7 +144,7 @@ def _resolve_agent_context(chat_id: str, request: Request) -> tuple[JwtClaims, S


def _build_deps(payload: ReplStreamRequest, request: Request, claims: JwtClaims, canvas_id: str) -> AgentDeps:
base_url = _resolve_required(payload.base_url, "SUPERPLANE_BASE_URL")
base_url = _resolve_required(payload.base_url, config.superplane_base_url, "SUPERPLANE_BASE_URL")
api_token = _resolve_required_bearer_token(request)
client = SuperplaneClient(
SuperplaneClientConfig(
Expand Down Expand Up @@ -413,19 +414,22 @@ def _create_app() -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
store = SessionStore()
tracker = ActiveStreamTracker()
app.state.session_store = store
app.state.stream_tracker = tracker
grpc_server = InternalAgentServer.from_env(store)
grpc_server.start()
app.state.internal_agent_server = grpc_server
try:
yield
finally:
tracker.begin_shutdown()
await tracker.wait_for_drain()
grpc_server.stop()
store.close()

app = FastAPI(lifespan=lifespan)
cors_origins_raw = os.getenv("REPL_WEB_CORS_ORIGINS", "*")
cors_origins = [origin.strip() for origin in cors_origins_raw.split(",") if origin.strip()]
cors_origins = [origin.strip() for origin in config.cors_origins.split(",") if origin.strip()]
if not cors_origins:
cors_origins = ["*"]

Expand All @@ -438,42 +442,53 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:

@app.post("/agents/chats/{chat_id}/stream")
async def stream_repl(chat_id: str, payload: ReplStreamRequest, request: Request) -> StreamingResponse:
tracker: ActiveStreamTracker = request.app.state.stream_tracker
if tracker.is_shutting_down:
raise HTTPException(status_code=503, detail="Service is shutting down")

if payload.model != "test" and _resolve_bearer_token(request) is None:
raise HTTPException(status_code=401, detail="Authorization header is required")

_debug_log(
"incoming stream request",
chat_id=chat_id,
model=payload.model,
has_base_url=bool(normalize_optional(payload.base_url) or normalize_optional(os.getenv("SUPERPLANE_BASE_URL"))),
has_base_url=bool(normalize_optional(payload.base_url) or normalize_optional(config.superplane_base_url)),
has_token=bool(_resolve_bearer_token(request)),
)

async def event_generator() -> AsyncIterator[str]:
try:
async for event in _stream_agent_run(chat_id, payload, request):
if await request.is_disconnected():
_debug_log("client disconnected", chat_id=chat_id)
break
yield _encode_sse_event(event)
except Exception as error:
_debug_log("stream failed", chat_id=chat_id, error=str(error))
yield _encode_sse_event(
{
"type": "run_failed",
"error": str(error),
}
)
yield _encode_sse_event({"type": "done"})

return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"cache-control": "no-cache",
"connection": "keep-alive",
},
)
await tracker.acquire()
try:
async def event_generator() -> AsyncIterator[str]:
try:
async for event in _stream_agent_run(chat_id, payload, request):
if await request.is_disconnected():
_debug_log("client disconnected", chat_id=chat_id)
break
yield _encode_sse_event(event)
except Exception as error:
_debug_log("stream failed", chat_id=chat_id, error=str(error))
yield _encode_sse_event(
{
"type": "run_failed",
"error": str(error),
}
)
yield _encode_sse_event({"type": "done"})
finally:
await tracker.release()

return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"cache-control": "no-cache",
"connection": "keep-alive",
},
)
except BaseException:
await tracker.release()
raise

return app

Expand Down
31 changes: 12 additions & 19 deletions agent/src/ai/session_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import os
import threading
import uuid
from collections.abc import Iterator
Expand All @@ -10,6 +9,8 @@

import psycopg
from psycopg.rows import dict_row

from ai.config import config
from psycopg.types.json import Jsonb
from pydantic_ai.messages import (
ModelMessage,
Expand Down Expand Up @@ -164,20 +165,12 @@ class SessionStoreConfig:

@classmethod
def from_env(cls) -> "SessionStoreConfig":
host = (os.getenv("DB_HOST") or "db").strip()
port = int((os.getenv("DB_PORT") or "5432").strip())
dbname = (os.getenv("DB_NAME") or "").strip()
user = (os.getenv("DB_USERNAME") or "").strip()
password = (os.getenv("DB_PASSWORD") or "").strip()
sslmode = (os.getenv("DB_SSLMODE") or "disable").strip() or "disable"
application_name = (os.getenv("APPLICATION_NAME") or "superplane-agent").strip() or "superplane-agent"

missing_fields = [
name
for name, value in (
("DB_NAME", dbname),
("DB_USERNAME", user),
("DB_PASSWORD", password),
("DB_NAME", config.db_name),
("DB_USERNAME", config.db_username),
("DB_PASSWORD", config.db_password),
)
if not value
]
Expand All @@ -186,13 +179,13 @@ def from_env(cls) -> "SessionStoreConfig":
raise ValueError(f"Missing required agent database settings: {joined}")

return cls(
host=host,
port=port,
dbname=dbname,
user=user,
password=password,
sslmode=sslmode,
application_name=application_name,
host=config.db_host,
port=config.db_port,
dbname=config.db_name,
user=config.db_username,
password=config.db_password,
sslmode=config.db_sslmode,
application_name=config.application_name,
)


Expand Down
Loading
Loading