Skip to content

Commit e1709fd

Browse files
authored
Merge pull request #26 from bertheto/feat/image-upload-hardening
feat(security): harden image upload — extension allowlist, magic bytes, size cap
2 parents 66abcc6 + ac38af5 commit e1709fd

6 files changed

Lines changed: 510 additions & 64 deletions

File tree

src/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
# Expose per-thread resources in MCP server (default: false for cleaner MCP client UI)
6161
# When enabled, each thread gets transcript, summary (if closed), and state resources
6262
EXPOSE_THREAD_RESOURCES = os.getenv("AGENTCHATBUS_EXPOSE_THREAD_RESOURCES", "false").lower() in {"1", "true", "yes"}
63+
# Admin token for settings endpoint (optional — if unset, PUT /api/settings is unprotected)
64+
ADMIN_TOKEN: str | None = os.getenv("AGENTCHATBUS_ADMIN_TOKEN")
6365

6466
def get_config_dict():
6567
return {

src/db/crud.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ async def thread_create(
157157

158158
try:
159159
await db.execute(
160-
"INSERT INTO threads (id, topic, status, created_at, updated_at, metadata, system_prompt, template_id) VALUES (?, ?, 'discuss', ?, ?, ?, ?, ?)",
161-
(tid, topic, now, now, meta_json, system_prompt, template_id),
160+
"INSERT INTO threads (id, topic, status, created_at, metadata, system_prompt, template_id) VALUES (?, ?, 'discuss', ?, ?, ?, ?)",
161+
(tid, topic, now, meta_json, system_prompt, template_id),
162162
)
163163
await db.commit()
164164
await _emit_event(db, "thread.new", tid, {"thread_id": tid, "topic": topic})
165165
logger.info(f"Thread created: {tid} '{topic}'")
166166
return Thread(id=tid, topic=topic, status="discuss", created_at=_parse_dt(now),
167-
updated_at=_parse_dt(now), closed_at=None, summary=None, metadata=meta_json,
168-
system_prompt=system_prompt, template_id=template_id)
167+
closed_at=None, summary=None, metadata=meta_json, system_prompt=system_prompt,
168+
template_id=template_id)
169169
except sqlite3.IntegrityError as e:
170170
# UNIQUE constraint violation on threads.topic — another thread was created concurrently
171171
# Fetch and return the existing thread for idempotency
@@ -197,23 +197,21 @@ async def thread_list(
197197
status: Optional[str] = None,
198198
include_archived: bool = False,
199199
) -> list[Thread]:
200-
# Order by updated_at DESC (most recent activity first), fallback to created_at
201-
order_by = "ORDER BY COALESCE(updated_at, created_at) DESC"
202200
if status:
203201
async with db.execute(
204-
f"SELECT * FROM threads WHERE status = ? {order_by}",
202+
"SELECT * FROM threads WHERE status = ? ORDER BY created_at DESC",
205203
(status,),
206204
) as cur:
207205
rows = await cur.fetchall()
208206
return [_row_to_thread(r) for r in rows]
209207

210208
if include_archived:
211-
async with db.execute(f"SELECT * FROM threads {order_by}") as cur:
209+
async with db.execute("SELECT * FROM threads ORDER BY created_at DESC") as cur:
212210
rows = await cur.fetchall()
213211
return [_row_to_thread(r) for r in rows]
214212

215213
async with db.execute(
216-
f"SELECT * FROM threads WHERE status != 'archived' {order_by}"
214+
"SELECT * FROM threads WHERE status != 'archived' ORDER BY created_at DESC"
217215
) as cur:
218216
rows = await cur.fetchall()
219217
return [_row_to_thread(r) for r in rows]
@@ -368,14 +366,14 @@ async def _get_new_messages_since(
368366
def _row_to_thread(row: aiosqlite.Row) -> Thread:
369367
keys = row.keys()
370368
system_prompt = row["system_prompt"] if "system_prompt" in keys else None
371-
updated_at = _parse_dt(row["updated_at"]) if "updated_at" in keys and row["updated_at"] else None
372369
template_id = row["template_id"] if "template_id" in keys else None
370+
updated_at_raw = row["updated_at"] if "updated_at" in keys else None
373371
return Thread(
374372
id=row["id"],
375373
topic=row["topic"],
376374
status=row["status"],
377375
created_at=_parse_dt(row["created_at"]),
378-
updated_at=updated_at,
376+
updated_at=_parse_dt(updated_at_raw) if updated_at_raw else None,
379377
closed_at=_parse_dt(row["closed_at"]) if row["closed_at"] else None,
380378
summary=row["summary"],
381379
metadata=row["metadata"],
@@ -569,10 +567,8 @@ async def msg_post(
569567
"INSERT INTO messages (id, thread_id, author, role, content, seq, created_at, metadata, author_id, author_name) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
570568
(mid, thread_id, actual_author, role, content, seq, now, meta_json, author_id, author_name),
571569
)
572-
# Update thread's updated_at to reflect latest activity
573570
await db.execute(
574-
"UPDATE threads SET updated_at = ? WHERE id = ?",
575-
(now, thread_id),
571+
"UPDATE threads SET updated_at = ? WHERE id = ?", (now, thread_id)
576572
)
577573
async with db.execute(
578574
"UPDATE reply_tokens SET status = 'consumed', consumed_at = ? "
@@ -597,16 +593,22 @@ async def msg_post(
597593
"msg_id": mid, "thread_id": thread_id, "author": author_name,
598594
"author_id": author_id, "role": role, "seq": seq, "content": content[:200], # truncate for event payload
599595
})
596+
_VALID_STOP_REASONS = {"convergence", "timeout", "error", "complete", "impasse"}
597+
600598
if metadata:
601-
if metadata.get("handoff_target"):
599+
handoff_target = metadata.get("handoff_target")
600+
if handoff_target:
602601
await _emit_event(db, "msg.handoff", thread_id, {
603602
"msg_id": mid, "thread_id": thread_id,
604-
"from_agent": author_name, "to_agent": metadata["handoff_target"],
603+
"from_agent": author_name, "to_agent": handoff_target,
605604
})
606-
if metadata.get("stop_reason"):
605+
stop_reason = metadata.get("stop_reason")
606+
if stop_reason:
607+
if stop_reason not in _VALID_STOP_REASONS:
608+
raise ValueError(f"Invalid stop_reason '{stop_reason}'. Must be one of: {', '.join(sorted(_VALID_STOP_REASONS))}")
607609
await _emit_event(db, "msg.stop", thread_id, {
608610
"msg_id": mid, "thread_id": thread_id,
609-
"agent": author_name, "reason": metadata["stop_reason"],
611+
"agent": author_name, "reason": stop_reason,
610612
})
611613
logger.debug(f"Message posted: seq={seq} author={author_name} thread={thread_id}")
612614
return Message(
@@ -848,6 +850,13 @@ async def agent_get(db: aiosqlite.Connection, agent_id: str) -> Optional[AgentIn
848850
return _row_to_agent(row) if row else None
849851

850852

853+
async def agent_verify_token(db: aiosqlite.Connection, agent_id: str, token: str) -> bool:
    """Read-only token check — does not update last_seen or heartbeat.

    Looks up the token stored for ``agent_id`` in the ``agents`` table and
    compares it with ``token`` in constant time.

    Args:
        db: Open database connection.
        agent_id: Value matched against ``agents.id``.
        token: Candidate token supplied by the caller.

    Returns:
        True only when a row exists for ``agent_id`` and its stored token is a
        string equal to ``token``. Unknown agents, a NULL stored token, and a
        non-string ``token`` all yield False.
    """
    # Local import keeps this block self-contained; hmac is standard library.
    import hmac

    async with db.execute("SELECT token FROM agents WHERE id = ?", (agent_id,)) as cur:
        row = await cur.fetchone()
    if row is None:
        return False

    stored = row["token"]
    # Reject missing / non-string values explicitly: a plain `==` would treat
    # a NULL stored token and a None candidate as a successful match.
    if not isinstance(stored, str) or not isinstance(token, str):
        return False

    # compare_digest avoids leaking how many leading characters matched via
    # response timing. Encode first: the str form only accepts ASCII input.
    return hmac.compare_digest(stored.encode("utf-8"), token.encode("utf-8"))
858+
859+
851860
async def agent_update(
852861
db: aiosqlite.Connection,
853862
agent_id: str,

src/main.py

Lines changed: 110 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
from typing import Literal
1818

1919
import uvicorn
20-
from fastapi import FastAPI, Request, HTTPException
20+
from fastapi import FastAPI, Request, HTTPException, Header
2121
from starlette.responses import Response
2222
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse, PlainTextResponse
2323
from fastapi.staticfiles import StaticFiles
2424
from pydantic import BaseModel, ConfigDict
2525
from mcp.server.sse import SseServerTransport
2626
from starlette.routing import Mount
2727

28-
from src.config import HOST, PORT, get_config_dict, save_config_dict
28+
from src.config import HOST, PORT, get_config_dict, save_config_dict, ADMIN_TOKEN
2929
from src.db.database import get_db, close_db
3030
from src.db import crud
3131
from src.db.crud import (
@@ -146,7 +146,7 @@ async def __call__(self, scope, receive, send):
146146
@app.get("/mcp/sse")
147147
async def mcp_sse_endpoint(request: Request):
148148
"""MCP SSE endpoint consumed by MCP clients (Claude Desktop, Cursor, …)."""
149-
from src.mcp_server import init_session_id, clear_connection_agent
149+
from src.mcp_server import init_session_id
150150

151151
# Initialize unique session ID for this SSE connection
152152
session_id = init_session_id()
@@ -168,9 +168,6 @@ async def mcp_sse_endpoint(request: Request):
168168
# Most are normal disconnects (anyio.ClosedResourceError, CancelledError…).
169169
# Log at DEBUG to avoid polluting the terminal.
170170
logger.debug("MCP SSE session ended: %s: %s", type(exc).__name__, exc)
171-
finally:
172-
# Ensure per-connection agent mapping is cleaned when SSE disconnects.
173-
clear_connection_agent(session_id)
174171
return _SseCompletedResponse()
175172

176173

@@ -254,6 +251,7 @@ async def api_threads(status: str | None = None, include_archived: bool = False)
254251

255252
@app.get("/api/threads/{thread_id}/messages")
256253
async def api_messages(thread_id: str, after_seq: int = 0, limit: int = 200, include_system_prompt: bool = False):
254+
limit = min(limit, 1000) # server-side hard cap — prevents memory exhaustion
257255
try:
258256
db = await asyncio.wait_for(get_db(), timeout=DB_TIMEOUT)
259257
t = await asyncio.wait_for(crud.thread_get(db, thread_id), timeout=DB_TIMEOUT)
@@ -284,6 +282,33 @@ async def api_messages(thread_id: str, after_seq: int = 0, limit: int = 200, inc
284282

285283
UPLOAD_DIR = Path(__file__).resolve().parent / "static" / "uploads"
286284

285+
# ── Image upload hardening (QW-01) ─────────────────────────────────────────
286+
# Max upload size: 5 MB. Prevents memory exhaustion / disk DoS.
287+
_MAX_IMAGE_BYTES = int(os.getenv("AGENTCHATBUS_MAX_IMAGE_BYTES", str(5 * 1024 * 1024)))
288+
289+
# Allowlist of safe extensions mapped to their expected magic-byte signatures.
290+
# Only files whose first bytes match the declared extension are accepted.
291+
_ALLOWED_IMAGE_EXTS: dict[str, list[bytes]] = {
    ".jpg": [b"\xff\xd8\xff"],
    ".jpeg": [b"\xff\xd8\xff"],
    ".png": [b"\x89PNG\r\n\x1a\n"],
    ".gif": [b"GIF87a", b"GIF89a"],
    # "RIFF" alone is only the generic container tag (shared with WAV, AVI, …);
    # _magic_bytes_ok additionally requires the WEBP form type at offset 8.
    ".webp": [b"RIFF"],
}


def _ext_from_filename(filename: str) -> str:
    """Return lowercase extension; map .jpe / .jfif → .jpg for uniformity.

    Only the final suffix is considered, so ``"a.tar.PNG"`` yields ``".png"``
    and a name without a suffix yields ``""``.
    """
    ext = Path(filename).suffix.lower()
    return ".jpg" if ext in {".jpe", ".jfif"} else ext


def _magic_bytes_ok(data: bytes, ext: str) -> bool:
    """Return True if the first bytes of data match any known signature for ext.

    Args:
        data: Leading bytes (or full contents) of the uploaded file.
        ext: Normalised lowercase extension including the dot, e.g. ``".png"``.

    Returns:
        False for extensions missing from ``_ALLOWED_IMAGE_EXTS``, for data
        shorter than the signature, and for RIFF files that are not WebP.
    """
    signatures = _ALLOWED_IMAGE_EXTS.get(ext, [])
    if not any(data.startswith(sig) for sig in signatures):
        return False
    if ext == ".webp":
        # RIFF layout: bytes 0-3 "RIFF", 4-7 chunk size, 8-11 form type.
        # Without this check any RIFF file (audio, video, cursors) renamed
        # to .webp would pass validation.
        return data[8:12] == b"WEBP"
    return True
310+
311+
287312
@app.post("/api/upload/image")
288313
async def api_upload_image(request: Request):
289314
"""Upload an image and return its URL."""
@@ -292,27 +317,33 @@ async def api_upload_image(request: Request):
292317
file = form.get("file")
293318
if not file or not file.filename:
294319
raise HTTPException(status_code=400, detail="No file provided")
295-
296-
# Validate file type
297-
if not file.content_type.startswith("image/"):
298-
raise HTTPException(status_code=400, detail="File must be an image")
299-
300-
# Create upload directory if it doesn't exist
320+
321+
ext = _ext_from_filename(file.filename)
322+
if ext not in _ALLOWED_IMAGE_EXTS:
323+
raise HTTPException(
324+
status_code=400,
325+
detail=f"Unsupported file type '{ext}'. Allowed: {', '.join(_ALLOWED_IMAGE_EXTS)}",
326+
)
327+
328+
# Read with size cap to prevent memory exhaustion
329+
contents = await file.read(_MAX_IMAGE_BYTES + 1)
330+
if len(contents) > _MAX_IMAGE_BYTES:
331+
raise HTTPException(
332+
status_code=413,
333+
detail=f"File too large. Maximum size is {_MAX_IMAGE_BYTES // (1024 * 1024)} MB",
334+
)
335+
336+
# Verify magic bytes — guards against renamed executables / polyglots
337+
if not _magic_bytes_ok(contents, ext):
338+
raise HTTPException(status_code=400, detail="File content does not match its extension")
339+
301340
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
302-
303-
# Generate unique filename
304-
ext = Path(file.filename).suffix or ".png"
305341
unique_name = f"{uuid.uuid4()}{ext}"
306342
file_path = UPLOAD_DIR / unique_name
307-
308-
# Save file
309-
contents = await file.read()
310343
with open(file_path, "wb") as f:
311344
f.write(contents)
312-
313-
# Return URL
314-
file_url = f"/static/uploads/{unique_name}"
315-
return {"url": file_url, "name": file.filename}
345+
346+
return {"url": f"/static/uploads/{unique_name}", "name": file.filename}
316347
except HTTPException:
317348
raise
318349
except Exception as e:
@@ -365,7 +396,9 @@ class SettingsUpdate(BaseModel):
365396
MSG_WAIT_TIMEOUT: int | None = None
366397

367398
@app.put("/api/settings")
368-
async def api_update_settings(body: SettingsUpdate):
399+
async def api_update_settings(body: SettingsUpdate, x_admin_token: str | None = Header(default=None)):
400+
if ADMIN_TOKEN and x_admin_token != ADMIN_TOKEN:
401+
raise HTTPException(status_code=401, detail="Invalid admin token")
369402
update_data = {k: v for k, v in body.model_dump().items() if v is not None}
370403
if update_data:
371404
save_config_dict(update_data)
@@ -389,6 +422,8 @@ class TemplateCreate(BaseModel):
389422
description: str | None = None
390423
system_prompt: str | None = None
391424
default_metadata: dict | None = None
425+
agent_id: str | None = None # optional — if provided with token, must match a registered agent
426+
token: str | None = None # optional — required only when agent_id is provided
392427

393428
class MessageCreate(BaseModel):
394429
model_config = ConfigDict(
@@ -439,15 +474,9 @@ async def api_get_template(template_id: str):
439474
raise HTTPException(status_code=503, detail="Database operation timeout")
440475
if t is None:
441476
raise HTTPException(status_code=404, detail="Template not found")
442-
default_metadata = None
443-
if t.default_metadata:
444-
try:
445-
default_metadata = json.loads(t.default_metadata)
446-
except (TypeError, ValueError):
447-
default_metadata = t.default_metadata
448477
return {
449478
"id": t.id, "name": t.name, "description": t.description,
450-
"system_prompt": t.system_prompt, "default_metadata": default_metadata,
479+
"system_prompt": t.system_prompt, "default_metadata": t.default_metadata,
451480
"is_builtin": t.is_builtin, "created_at": t.created_at.isoformat(),
452481
}
453482

@@ -456,6 +485,25 @@ async def api_get_template(template_id: str):
456485
async def api_create_template(body: TemplateCreate):
457486
try:
458487
db = await asyncio.wait_for(get_db(), timeout=DB_TIMEOUT)
488+
except asyncio.TimeoutError:
489+
raise HTTPException(status_code=503, detail="Database operation timeout")
490+
491+
# QW-06: if agent_id + token provided, verify they match a registered agent
492+
if body.agent_id and body.token:
493+
token_valid = await asyncio.wait_for(
494+
crud.agent_verify_token(db, body.agent_id, body.token), timeout=DB_TIMEOUT
495+
)
496+
if not token_valid:
497+
raise HTTPException(status_code=401, detail="Invalid agent_id or token")
498+
499+
# QW-07: apply content filter to system_prompt to block embedded secrets
500+
if body.system_prompt:
501+
from src.content_filter import check_content, ContentFilterError as _CFE
502+
blocked, pattern = check_content(body.system_prompt)
503+
if blocked:
504+
raise HTTPException(status_code=400, detail={"error": "system_prompt blocked by content filter", "pattern": pattern})
505+
506+
try:
459507
t = await asyncio.wait_for(
460508
crud.template_create(
461509
db,
@@ -507,6 +555,13 @@ async def api_sync_context(thread_id: str, body: SyncContextRequest | None = Non
507555

508556
@app.post("/api/threads", status_code=201)
509557
async def api_create_thread(body: ThreadCreate):
558+
# QW-07: apply content filter to system_prompt to block embedded secrets
559+
if body.system_prompt:
560+
from src.content_filter import check_content
561+
blocked, pattern = check_content(body.system_prompt)
562+
if blocked:
563+
raise HTTPException(status_code=400, detail={"error": "system_prompt blocked by content filter", "pattern": pattern})
564+
510565
try:
511566
db = await asyncio.wait_for(get_db(), timeout=DB_TIMEOUT)
512567
t = await asyncio.wait_for(
@@ -521,15 +576,33 @@ async def api_create_thread(body: ThreadCreate):
521576
"template_id": t.template_id, "created_at": t.created_at.isoformat()}
522577

523578
@app.post("/api/threads/{thread_id}/messages", status_code=201)
524-
async def api_post_message(thread_id: str, body: MessageCreate):
579+
async def api_post_message(thread_id: str, body: MessageCreate, x_agent_token: str | None = Header(default=None)):
525580
try:
526581
db = await asyncio.wait_for(get_db(), timeout=DB_TIMEOUT)
527582
t = await asyncio.wait_for(crud.thread_get(db, thread_id), timeout=DB_TIMEOUT)
528583
except asyncio.TimeoutError:
529584
raise HTTPException(status_code=503, detail="Database operation timeout")
530585
if t is None:
531586
raise HTTPException(status_code=404, detail="Thread not found")
532-
587+
588+
# Vector B: prevent role escalation from human/anonymous senders
589+
if body.role == "system" and body.author in ("human", ""):
590+
raise HTTPException(status_code=400, detail="role 'system' is not allowed for human messages")
591+
592+
# Vector C: if author matches a known agent_id, require a valid token
593+
try:
594+
known_agent = await asyncio.wait_for(crud.agent_get(db, body.author), timeout=DB_TIMEOUT)
595+
except asyncio.TimeoutError:
596+
known_agent = None
597+
if known_agent is not None:
598+
if not x_agent_token:
599+
raise HTTPException(status_code=401, detail="X-Agent-Token header required to post as a registered agent")
600+
token_valid = await asyncio.wait_for(
601+
crud.agent_verify_token(db, body.author, x_agent_token), timeout=DB_TIMEOUT
602+
)
603+
if not token_valid:
604+
raise HTTPException(status_code=401, detail="Invalid agent token")
605+
533606
msg_metadata = body.metadata or {}
534607
if body.mentions:
535608
msg_metadata["mentions"] = body.mentions
@@ -579,9 +652,10 @@ async def api_post_message(thread_id: str, body: MessageCreate):
579652
})
580653
except ContentFilterError as e:
581654
raise HTTPException(status_code=400, detail={"error": "Content blocked by filter", "pattern": e.pattern_name})
655+
except ValueError as e:
656+
raise HTTPException(status_code=400, detail=str(e))
582657
except asyncio.TimeoutError:
583658
raise HTTPException(status_code=503, detail="Database operation timeout")
584-
585659
except RateLimitExceeded as e:
586660
from fastapi.responses import JSONResponse
587661
return JSONResponse(
@@ -912,14 +986,14 @@ async def api_thread_export(thread_id: str):
912986
except asyncio.TimeoutError:
913987
raw_topic = thread_id
914988

915-
slug = re.sub(r"[^\w\-]", "-", raw_topic.lower())
916-
slug = re.sub(r"-+", "-", slug).strip("-") or "thread"
989+
slug = re.sub(r"[^\w\-]", "-", raw_topic.lower(), flags=re.ASCII)
990+
slug = re.sub(r"-+", "-", slug, flags=re.ASCII).strip("-")[:80] or "thread"
917991
filename = f"{slug}.md"
918992

919993
return PlainTextResponse(
920994
content=md,
921995
media_type="text/markdown; charset=utf-8",
922-
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
996+
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
923997
)
924998

925999

0 commit comments

Comments
 (0)