-
Notifications
You must be signed in to change notification settings - Fork 813
Expand file tree
/
Copy pathbank_utils.py
More file actions
342 lines (271 loc) · 10.5 KB
/
bank_utils.py
File metadata and controls
342 lines (271 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""
bank profile utilities for disposition and mission management.
"""
import json
import logging
import re
import uuid
from typing import TypedDict
from pydantic import BaseModel, Field
from ...config import get_config
from ..db_utils import acquire_with_retry
from ..memory_engine import fq_table, get_current_schema
from ..response_models import DispositionTraits
logger = logging.getLogger(__name__)
# Fact types that get per-bank partial vector indexes, mapped to their 4-char index suffix.
_BANK_INDEX_FACT_TYPES: dict[str, str] = {
"world": "worl",
"experience": "expr",
"observation": "obsv",
}
def _bank_index_name(ft: str, internal_id: str) -> str:
"""Deterministic, schema-safe vector index name for a (bank, fact_type) pair.
Uses the first 16 hex chars of internal_id (8 bytes of entropy) — unique
enough in practice, fits comfortably within PostgreSQL's 63-char identifier limit.
"""
uid = str(internal_id).replace("-", "")[:16]
return f"idx_mu_emb_{_BANK_INDEX_FACT_TYPES[ft]}_{uid}"
def _vector_index_clause() -> str:
"""Return the USING clause for vector index creation based on the configured extension."""
ext = get_config().vector_extension
if ext == "pgvectorscale":
return "USING diskann (embedding vector_cosine_ops) WITH (num_neighbors = 50)"
elif ext == "vchord":
return "USING vchordrq (embedding vector_l2_ops)"
else: # pgvector (default)
return "USING hnsw (embedding vector_cosine_ops)"
async def create_bank_vector_indexes(conn, bank_id: str, internal_id: str) -> None:
"""Create per-(bank, fact_type) partial vector indexes for a newly created bank.
Respects the HINDSIGHT_API_VECTOR_EXTENSION config to use the appropriate
index type (HNSW for pgvector, DiskANN for pgvectorscale, vchordrq for vchord).
Called immediately after the bank row is first inserted. Safe on empty banks
(index build is instant). Idempotent via CREATE INDEX IF NOT EXISTS.
bank_id is escaped for SQL literal safety (apostrophes doubled).
"""
table = fq_table("memory_units")
escaped = bank_id.replace("'", "''")
using_clause = _vector_index_clause()
for ft in _BANK_INDEX_FACT_TYPES:
idx = _bank_index_name(ft, internal_id)
await conn.execute(
f"CREATE INDEX IF NOT EXISTS {idx} "
f"ON {table} {using_clause} "
f"WHERE fact_type = '{ft}' AND bank_id = '{escaped}'"
)
async def drop_bank_vector_indexes(conn, internal_id: str) -> None:
"""Drop per-(bank, fact_type) partial vector indexes for a bank being deleted.
Called before the bank row is deleted so internal_id is still known.
Idempotent via DROP INDEX IF EXISTS.
"""
schema = get_current_schema()
for ft in _BANK_INDEX_FACT_TYPES:
idx = _bank_index_name(ft, internal_id)
await conn.execute(f"DROP INDEX IF EXISTS {schema}.{idx}")
DEFAULT_DISPOSITION = {
"skepticism": 3,
"literalism": 3,
"empathy": 3,
}
class BankProfile(TypedDict):
"""Type for bank profile data."""
name: str
disposition: DispositionTraits
mission: str
class MissionMergeResponse(BaseModel):
"""LLM response for mission merge."""
mission: str = Field(description="Merged mission in first person perspective")
async def get_bank_profile(pool, bank_id: str) -> BankProfile:
"""
Get bank profile (name, disposition + mission).
Auto-creates bank with default values if not exists.
Args:
pool: Database connection pool
bank_id: bank IDentifier
Returns:
BankProfile with name, typed DispositionTraits, and mission
"""
async with acquire_with_retry(pool) as conn:
# Try to get existing bank
row = await conn.fetchrow(
f"""
SELECT name, disposition, mission
FROM {fq_table("banks")} WHERE bank_id = $1
""",
bank_id,
)
if row:
# asyncpg returns JSONB as a string, so parse it
disposition_data = row["disposition"]
if isinstance(disposition_data, str):
disposition_data = json.loads(disposition_data)
return BankProfile(
name=row["name"],
disposition=DispositionTraits(**disposition_data),
mission=row["mission"] or "",
)
# Bank doesn't exist, create with defaults.
# Generate internal_id here so we control the value and can use it
# immediately for vector index creation without a RETURNING round-trip.
internal_id = uuid.uuid4()
inserted = await conn.fetchval(
f"""
INSERT INTO {fq_table("banks")} (bank_id, name, disposition, mission, internal_id)
VALUES ($1, $2, $3::jsonb, $4, $5)
ON CONFLICT (bank_id) DO NOTHING
RETURNING bank_id
""",
bank_id,
bank_id, # Default name is the bank_id
json.dumps(DEFAULT_DISPOSITION),
"",
internal_id,
)
if inserted:
# Fresh insert — create per-bank vector indexes (instant on empty bank)
await create_bank_vector_indexes(conn, bank_id, str(internal_id))
return BankProfile(name=bank_id, disposition=DispositionTraits(**DEFAULT_DISPOSITION), mission="")
async def update_bank_disposition(pool, bank_id: str, disposition: dict[str, int]) -> None:
"""
Update bank disposition traits.
Args:
pool: Database connection pool
bank_id: bank IDentifier
disposition: Dict with skepticism, literalism, empathy (all 1-5)
"""
# Ensure bank exists first
await get_bank_profile(pool, bank_id)
async with acquire_with_retry(pool) as conn:
await conn.execute(
f"""
UPDATE {fq_table("banks")}
SET disposition = $2::jsonb,
updated_at = NOW()
WHERE bank_id = $1
""",
bank_id,
json.dumps(disposition),
)
async def set_bank_mission(pool, bank_id: str, mission: str) -> None:
"""
Set bank mission (replacing any existing mission).
Args:
pool: Database connection pool
bank_id: bank IDentifier
mission: The mission text
"""
# Ensure bank exists first
await get_bank_profile(pool, bank_id)
async with acquire_with_retry(pool) as conn:
await conn.execute(
f"""
UPDATE {fq_table("banks")}
SET mission = $2,
updated_at = NOW()
WHERE bank_id = $1
""",
bank_id,
mission,
)
async def merge_bank_mission(pool, llm_config, bank_id: str, new_info: str) -> dict:
"""
Merge new mission information with existing mission using LLM.
Normalizes to first person ("I") and resolves conflicts.
Args:
pool: Database connection pool
llm_config: LLM configuration for mission merging
bank_id: bank IDentifier
new_info: New mission information to add/merge
Returns:
Dict with 'mission' (str) key
"""
# Get current profile
profile = await get_bank_profile(pool, bank_id)
current_mission = profile["mission"]
# Use LLM to merge missions
result = await _llm_merge_mission(llm_config, current_mission, new_info)
merged_mission = result["mission"]
# Update in database
async with acquire_with_retry(pool) as conn:
await conn.execute(
f"""
UPDATE {fq_table("banks")}
SET mission = $2,
updated_at = NOW()
WHERE bank_id = $1
""",
bank_id,
merged_mission,
)
return {"mission": merged_mission}
async def _llm_merge_mission(llm_config, current: str, new_info: str) -> dict:
"""
Use LLM to intelligently merge mission information.
Args:
llm_config: LLM configuration to use
current: Current mission text
new_info: New information to merge
Returns:
Dict with 'mission' (str) key
"""
prompt = f"""You are helping maintain an agent's mission statement.
Current mission: {current if current else "(empty)"}
New information to add: {new_info}
Instructions:
1. Merge the new information with the current mission
2. If there are conflicts, the NEW information overwrites the old
3. Keep additions that don't conflict
4. Output in FIRST PERSON ("I") perspective
5. Be concise - keep it under 500 characters
6. Return ONLY the merged mission text, no explanations
Merged mission:"""
try:
messages = [{"role": "user", "content": prompt}]
content = await llm_config.call(
messages=messages, scope="bank_mission", temperature=0.3, max_completion_tokens=8192
)
logger.info(f"LLM response for mission merge (first 500 chars): {content[:500]}")
merged = content.strip()
if not merged or merged.lower() in ["(empty)", "none", "n/a"]:
merged = new_info if new_info else ""
return {"mission": merged}
except Exception as e:
logger.error(f"Error merging mission with LLM: {e}")
# Fallback: just append new info
if current:
merged = f"{current} {new_info}".strip()
else:
merged = new_info
return {"mission": merged}
async def list_banks(pool) -> list:
"""
List all banks in the system.
Args:
pool: Database connection pool
Returns:
List of dicts with bank_id, name, disposition, mission, created_at, updated_at
"""
async with acquire_with_retry(pool) as conn:
rows = await conn.fetch(
f"""
SELECT bank_id, name, disposition, mission, created_at, updated_at
FROM {fq_table("banks")}
ORDER BY updated_at DESC
"""
)
result = []
for row in rows:
# asyncpg returns JSONB as a string, so parse it
disposition_data = row["disposition"]
if isinstance(disposition_data, str):
disposition_data = json.loads(disposition_data)
result.append(
{
"bank_id": row["bank_id"],
"name": row["name"],
"disposition": disposition_data,
"mission": row["mission"] or "",
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
"updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
}
)
return result