Skip to content

Commit 876deb2

Browse files
committed
fixes
1 parent 9f10db6 commit 876deb2

10 files changed

Lines changed: 180 additions & 260 deletions

File tree

hindsight-api/hindsight_api/engine/memory_engine.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4159,10 +4159,10 @@ async def _invalidate_facts_from_mental_models(
41594159
fact_ids: list[str],
41604160
) -> int:
41614161
"""
4162-
Remove fact IDs from mental model observations when memories are deleted.
4162+
Remove fact IDs from mental model source_memory_ids when memories are deleted.
41634163
4164-
Uses JSONB path operations to find and update mental models that reference
4165-
the deleted fact IDs in their observations.
4164+
Mental models are now stored in memory_units with fact_type='mental_model'
4165+
and have a source_memory_ids column (UUID[]) tracking their source memories.
41664166
41674167
Args:
41684168
conn: Database connection
@@ -4175,45 +4175,29 @@ async def _invalidate_facts_from_mental_models(
41754175
if not fact_ids:
41764176
return 0
41774177

4178-
# Convert fact_ids to a jsonb array for efficient comparison
4179-
import json
4178+
# Convert string IDs to UUIDs for the array comparison
4179+
import uuid as uuid_module
41804180

4181-
fact_ids_json = json.dumps(fact_ids)
4181+
fact_uuids = [uuid_module.UUID(fid) for fid in fact_ids]
41824182

4183-
# Update mental models by removing the deleted fact IDs from all observations
4184-
# This uses jsonb_set to update each observation's fact_ids array
4183+
# Update mental models (memory_units with fact_type='mental_model')
4184+
# by removing the deleted fact IDs from source_memory_ids
4185+
# Use array subtraction: source_memory_ids - deleted_ids
41854186
result = await conn.execute(
41864187
f"""
4187-
UPDATE {fq_table("mental_models")}
4188-
SET observations = jsonb_set(
4189-
observations,
4190-
'{{observations}}',
4191-
(
4192-
SELECT COALESCE(jsonb_agg(
4193-
jsonb_set(
4194-
observation,
4195-
'{{fact_ids}}',
4196-
(
4197-
SELECT COALESCE(jsonb_agg(fid), '[]'::jsonb)
4198-
FROM jsonb_array_elements_text(observation->'fact_ids') AS fid
4199-
WHERE NOT (fid::text = ANY($2::text[]))
4200-
)
4201-
)
4202-
), '[]'::jsonb)
4203-
FROM jsonb_array_elements(observations->'observations') AS observation
4204-
)
4188+
UPDATE {fq_table("memory_units")}
4189+
SET source_memory_ids = (
4190+
SELECT COALESCE(array_agg(elem), ARRAY[]::uuid[])
4191+
FROM unnest(source_memory_ids) AS elem
4192+
WHERE elem != ALL($2::uuid[])
42054193
),
4206-
last_updated = NOW()
4194+
updated_at = NOW()
42074195
WHERE bank_id = $1
4208-
AND EXISTS (
4209-
SELECT 1
4210-
FROM jsonb_array_elements(observations->'observations') AS observation,
4211-
jsonb_array_elements_text(observation->'fact_ids') AS fid
4212-
WHERE fid::text = ANY($2::text[])
4213-
)
4196+
AND fact_type = 'mental_model'
4197+
AND source_memory_ids && $2::uuid[]
42144198
""",
42154199
bank_id,
4216-
fact_ids,
4200+
fact_uuids,
42174201
)
42184202

42194203
# Parse the result to get number of updated rows

hindsight-api/hindsight_api/engine/reflect/prompts.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,42 @@ def _extract_directive_rules(directives: list[dict[str, Any]]) -> list[str]:
1616
Extract directive rules as a list of strings.
1717
1818
Args:
19-
directives: List of directive mental models with observations
19+
directives: List of directives with name and content
2020
2121
Returns:
2222
List of directive rule strings
2323
"""
2424
rules = []
2525
for directive in directives:
2626
directive_name = directive.get("name", "")
27-
observations = directive.get("observations", [])
28-
if observations:
29-
for obs in observations:
30-
# Support both Pydantic Observation objects and dicts
31-
if hasattr(obs, "title"):
32-
title = obs.title
33-
content = obs.content
34-
else:
35-
title = obs.get("title", "")
36-
content = obs.get("content", "")
37-
if title and content:
38-
rules.append(f"**{title}**: {content}")
39-
elif content:
40-
rules.append(content)
41-
elif directive_name:
42-
# Fallback to description if no observations
43-
desc = directive.get("description", "")
44-
if desc:
45-
rules.append(f"**{directive_name}**: {desc}")
27+
# New format: directives have direct content field
28+
content = directive.get("content", "")
29+
if content:
30+
if directive_name:
31+
rules.append(f"**{directive_name}**: {content}")
32+
else:
33+
rules.append(content)
34+
else:
35+
# Legacy format: check for observations
36+
observations = directive.get("observations", [])
37+
if observations:
38+
for obs in observations:
39+
# Support both Pydantic Observation objects and dicts
40+
if hasattr(obs, "title"):
41+
title = obs.title
42+
obs_content = obs.content
43+
else:
44+
title = obs.get("title", "")
45+
obs_content = obs.get("content", "")
46+
if title and obs_content:
47+
rules.append(f"**{title}**: {obs_content}")
48+
elif obs_content:
49+
rules.append(obs_content)
50+
elif directive_name:
51+
# Fallback to description
52+
desc = directive.get("description", "")
53+
if desc:
54+
rules.append(f"**{directive_name}**: {desc}")
4655
return rules
4756

4857

hindsight-api/tests/test_consolidation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import uuid
8+
from unittest.mock import patch
89

910
import pytest
1011

@@ -17,6 +18,18 @@
1718
)
1819

1920

21+
@pytest.fixture(autouse=True)
22+
def enable_consolidation():
23+
"""Enable consolidation for all tests in this module."""
24+
from hindsight_api.config import get_config
25+
26+
config = get_config()
27+
original_value = config.enable_consolidation
28+
config.enable_consolidation = True
29+
yield
30+
config.enable_consolidation = original_value
31+
32+
2033
class TestConsolidationIntegration:
2134
"""Integration tests for consolidation with real database.
2235

hindsight-api/tests/test_llm_tools.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,24 +241,27 @@ def test_get_reflect_tools_default(self):
241241
tools = get_reflect_tools()
242242

243243
tool_names = [t["function"]["name"] for t in tools]
244-
assert "list_mental_models" in tool_names
245-
assert "get_mental_model" in tool_names
244+
assert "search_reflections" in tool_names
245+
assert "search_mental_models" in tool_names
246246
assert "recall" in tool_names
247-
assert "learn" in tool_names
248247
assert "expand" in tool_names
249248
assert "done" in tool_names
250249

251-
def test_get_reflect_tools_without_learn(self):
252-
"""Test getting reflect tools without learn."""
250+
def test_get_reflect_tools_with_directives(self):
251+
"""Test getting reflect tools with directive rules."""
253252
from hindsight_api.engine.reflect.tools_schema import get_reflect_tools
254253

255-
tools = get_reflect_tools(enable_learn=False)
254+
tools = get_reflect_tools(directive_rules=["Always respond in French"])
256255

257256
tool_names = [t["function"]["name"] for t in tools]
258-
assert "learn" not in tool_names
259257
assert "recall" in tool_names
260258
assert "done" in tool_names
261259

260+
# Done tool should have directive_compliance field when directives are present
261+
done_tool = next(t for t in tools if t["function"]["name"] == "done")
262+
params = done_tool["function"]["parameters"]["properties"]
263+
assert "directive_compliance" in params
264+
262265
def test_get_reflect_tools_answer_mode(self):
263266
"""Test getting reflect tools with answer output mode."""
264267
from hindsight_api.engine.reflect.tools_schema import get_reflect_tools
@@ -270,7 +273,8 @@ def test_get_reflect_tools_answer_mode(self):
270273

271274
assert "answer" in params
272275
assert "memory_ids" in params
273-
assert "model_ids" in params
276+
assert "mental_model_ids" in params
277+
assert "reflection_ids" in params
274278

275279

276280
class TestLLMToolCallResult:

hindsight-api/tests/test_main_module.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ def capture_uvicorn_run(**kwargs):
363363
RetainContext,
364364
RecallContext,
365365
ReflectContext,
366-
RefreshMentalModelContext,
367366
)
368367

369368

@@ -395,6 +394,3 @@ async def validate_recall(self, ctx: RecallContext) -> ValidationResult:
395394

396395
async def validate_reflect(self, ctx: ReflectContext) -> ValidationResult:
397396
return ValidationResult.accept()
398-
399-
async def validate_refresh_mental_model(self, ctx: RefreshMentalModelContext) -> ValidationResult:
400-
return ValidationResult.accept()

0 commit comments

Comments
 (0)