1313import aiosqlite
1414
1515from src .db .models import Thread , Message , AgentInfo , Event
16- from src .config import AGENT_HEARTBEAT_TIMEOUT , RATE_LIMIT_MSG_PER_MINUTE , RATE_LIMIT_ENABLED
17- from src .config import AGENT_HEARTBEAT_TIMEOUT , CONTENT_FILTER_ENABLED
16+ from src .config import (
17+ AGENT_HEARTBEAT_TIMEOUT ,
18+ RATE_LIMIT_MSG_PER_MINUTE ,
19+ RATE_LIMIT_ENABLED ,
20+ CONTENT_FILTER_ENABLED ,
21+ REPLY_TOKEN_LEASE_SECONDS ,
22+ SEQ_TOLERANCE ,
23+ SEQ_MISMATCH_MAX_MESSAGES ,
24+ )
1825from src .content_filter import check_content , ContentFilterError
1926
2027logger = logging .getLogger (__name__ )
@@ -30,6 +37,46 @@ def __init__(self, limit: int, window: int, retry_after: int, scope: str) -> Non
3037 super ().__init__ (f"Rate limit exceeded: { limit } messages/{ window } s" )
3138
3239
40+ class MissingSyncFieldsError (Exception ):
41+ """Raised when strict sync fields are absent from msg_post."""
42+
43+ def __init__ (self , missing_fields : list [str ]) -> None :
44+ self .missing_fields = missing_fields
45+ super ().__init__ (f"Missing required sync fields: { ', ' .join (missing_fields )} " )
46+
47+
48+ class SeqMismatchError (Exception ):
49+ """Raised when too many unseen messages exist since expected seq."""
50+
51+ def __init__ (self , expected_last_seq : int , current_seq : int , new_messages : list [dict ]) -> None :
52+ self .expected_last_seq = expected_last_seq
53+ self .current_seq = current_seq
54+ self .new_messages = new_messages
55+ super ().__init__ (
56+ f"SEQ_MISMATCH: expected_last_seq={ expected_last_seq } , current_seq={ current_seq } "
57+ )
58+
59+
60+ class ReplyTokenInvalidError (Exception ):
61+ def __init__ (self , token : str ) -> None :
62+ self .token = token
63+ super ().__init__ ("TOKEN_INVALID" )
64+
65+
66+ class ReplyTokenExpiredError (Exception ):
67+ def __init__ (self , token : str , expires_at : str ) -> None :
68+ self .token = token
69+ self .expires_at = expires_at
70+ super ().__init__ ("TOKEN_EXPIRED" )
71+
72+
73+ class ReplyTokenReplayError (Exception ):
74+ def __init__ (self , token : str , consumed_at : Optional [str ]) -> None :
75+ self .token = token
76+ self .consumed_at = consumed_at
77+ super ().__init__ ("TOKEN_REPLAY" )
78+
79+
3380GLOBAL_SYSTEM_PROMPT = """**SYSTEM DIRECTIVE: ACTIVE AGENT COLLABORATION WORKSPACE**
3481
3582Welcome to this Thread. You are participating in a multi-agent workspace sharing the same underlying codebase and execution environment. You MUST collaborate proactively and keep progress moving.
@@ -233,6 +280,69 @@ async def thread_latest_seq(db: aiosqlite.Connection, thread_id: str) -> int:
233280 return row ["max_seq" ] or 0
234281
235282
283+ async def _expire_old_reply_tokens (db : aiosqlite .Connection ) -> None :
284+ now = _now ()
285+ await db .execute (
286+ "UPDATE reply_tokens SET status = 'expired' "
287+ "WHERE status = 'issued' AND expires_at <= ?" ,
288+ (now ,),
289+ )
290+ await db .commit ()
291+
292+
293+ async def issue_reply_token (
294+ db : aiosqlite .Connection ,
295+ thread_id : str ,
296+ agent_id : Optional [str ] = None ,
297+ ) -> dict :
298+ """Issue a short-lived reply token bound to a thread (and optionally an agent)."""
299+ await _expire_old_reply_tokens (db )
300+ token = secrets .token_urlsafe (24 )
301+ issued_at = _now ()
302+ expires_at = (datetime .now (timezone .utc ) + timedelta (seconds = REPLY_TOKEN_LEASE_SECONDS )).isoformat ()
303+ await db .execute (
304+ "INSERT INTO reply_tokens (token, thread_id, agent_id, issued_at, expires_at, consumed_at, status) "
305+ "VALUES (?, ?, ?, ?, ?, NULL, 'issued')" ,
306+ (token , thread_id , agent_id , issued_at , expires_at ),
307+ )
308+ await db .commit ()
309+ current_seq = await thread_latest_seq (db , thread_id )
310+ return {
311+ "reply_token" : token ,
312+ "current_seq" : current_seq ,
313+ "reply_window" : {
314+ "expires_at" : expires_at ,
315+ "max_new_messages" : SEQ_TOLERANCE ,
316+ },
317+ }
318+
319+
320+ async def _get_new_messages_since (
321+ db : aiosqlite .Connection ,
322+ thread_id : str ,
323+ expected_last_seq : int ,
324+ limit : int = SEQ_MISMATCH_MAX_MESSAGES ,
325+ ) -> list [dict ]:
326+ msgs = await msg_list (
327+ db ,
328+ thread_id = thread_id ,
329+ after_seq = expected_last_seq ,
330+ limit = limit ,
331+ include_system_prompt = False ,
332+ )
333+ return [
334+ {
335+ "msg_id" : m .id ,
336+ "seq" : m .seq ,
337+ "author" : m .author ,
338+ "role" : m .role ,
339+ "content" : m .content ,
340+ "created_at" : m .created_at .isoformat (),
341+ }
342+ for m in msgs
343+ ]
344+
345+
236346def _row_to_thread (row : aiosqlite .Row ) -> Thread :
237347 system_prompt = row ["system_prompt" ] if "system_prompt" in row .keys () else None
238348 updated_at = _parse_dt (row ["updated_at" ]) if "updated_at" in row .keys () and row ["updated_at" ] else None
@@ -258,6 +368,8 @@ async def msg_post(
258368 thread_id : str ,
259369 author : str ,
260370 content : str ,
371+ expected_last_seq : int ,
372+ reply_token : str ,
261373 role : str = "user" ,
262374 metadata : Optional [dict ] = None ,
263375) -> Message :
@@ -310,6 +422,41 @@ async def msg_post(
310422 scope = scope ,
311423 )
312424
425+ missing_fields : list [str ] = []
426+ if expected_last_seq is None :
427+ missing_fields .append ("expected_last_seq" )
428+ if not reply_token :
429+ missing_fields .append ("reply_token" )
430+ if missing_fields :
431+ raise MissingSyncFieldsError (missing_fields )
432+
433+ await _expire_old_reply_tokens (db )
434+ async with db .execute (
435+ "SELECT token, thread_id, agent_id, expires_at, consumed_at, status "
436+ "FROM reply_tokens WHERE token = ?" ,
437+ (reply_token ,),
438+ ) as cur :
439+ token_row = await cur .fetchone ()
440+
441+ if token_row is None :
442+ raise ReplyTokenInvalidError (reply_token )
443+ if token_row ["thread_id" ] != thread_id :
444+ raise ReplyTokenInvalidError (reply_token )
445+ if token_row ["status" ] == "consumed" :
446+ raise ReplyTokenReplayError (reply_token , token_row ["consumed_at" ])
447+ if token_row ["status" ] == "expired" :
448+ raise ReplyTokenExpiredError (reply_token , token_row ["expires_at" ])
449+
450+ token_agent_id = token_row ["agent_id" ]
451+ if token_agent_id and author_id and token_agent_id != author_id :
452+ raise ReplyTokenInvalidError (reply_token )
453+
454+ current_seq = await thread_latest_seq (db , thread_id )
455+ new_messages_count = current_seq - expected_last_seq
456+ if new_messages_count > SEQ_TOLERANCE :
457+ new_messages = await _get_new_messages_since (db , thread_id , expected_last_seq )
458+ raise SeqMismatchError (expected_last_seq , current_seq , new_messages )
459+
313460 mid = str (uuid .uuid4 ())
314461 now = _now ()
315462 seq = await next_seq (db )
@@ -323,6 +470,22 @@ async def msg_post(
323470 "UPDATE threads SET updated_at = ? WHERE id = ?" ,
324471 (now , thread_id ),
325472 )
473+ async with db .execute (
474+ "UPDATE reply_tokens SET status = 'consumed', consumed_at = ? "
475+ "WHERE token = ? AND status = 'issued'" ,
476+ (now , reply_token ),
477+ ) as cur :
478+ consumed = cur .rowcount
479+ if consumed == 0 :
480+ await db .rollback ()
481+ async with db .execute (
482+ "SELECT consumed_at FROM reply_tokens WHERE token = ?" ,
483+ (reply_token ,),
484+ ) as cur :
485+ row = await cur .fetchone ()
486+ consumed_at = row ["consumed_at" ] if row else None
487+ raise ReplyTokenReplayError (reply_token , consumed_at )
488+
326489 await db .commit ()
327490 if author_id :
328491 await agent_msg_post (db , author_id )
0 commit comments