Skip to content

Commit 46a5d9e

Browse files
committed
fix: resolve 25 test regressions from streaming retain pipeline (#722)
The 3-phase retain pipeline (914ba79) introduced several regressions:

1. **Per-content tags lost** — the streaming pipeline used `contents[0].tags` for ALL chunks, breaking tag-based visibility. Fixed by tracking a chunk-to-content mapping so each chunk uses its source content's tags.
2. **Multi-document batches broken** — batches with per-content `document_id` values were merged into a single document. Fixed by grouping by `document_id` and processing each group independently.
3. **Migration ID collision** — `d6e7f8a9b0c1` was used by both `drop_documents_metadata` and `case_insensitive_entities_trgm_index`. Renamed the trgm migration to `e8f9a0b1c2d3`, fixed the revision chain, and added the missing schema prefix on `DROP INDEX`.
4. **Graph entity inheritance** — `get_graph_data` queried entities for observation IDs only, but observations inherit entities from their source memories. Fixed by querying `all_relevant_ids`.
5. **Docstring false positives** — docstrings in `link_utils.py` triggered the SQL schema safety test's unqualified-table-reference check.
6. **Config test count** — `retain_chunk_batch_size` was added to `_CONFIGURABLE_FIELDS` without updating the test assertion.
1 parent 1a1fb35 commit 46a5d9e

File tree

6 files changed

+78
-24
lines changed

6 files changed

+78
-24
lines changed

hindsight-api-slim/hindsight_api/alembic/versions/a4b5c6d7e8f9_fix_per_bank_vector_index_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Fix per-bank vector indexes to match configured extension
22
33
Revision ID: a4b5c6d7e8f9
4-
Revises: c2d3e4f5g6h7, c5d6e7f8a9b0
4+
Revises: e8f9a0b1c2d3
55
Create Date: 2026-04-01
66
77
Migration d5e6f7a8b9c0 hardcoded HNSW when creating per-bank partial vector
@@ -21,7 +21,7 @@
2121
from sqlalchemy import text
2222

2323
revision: str = "a4b5c6d7e8f9"
24-
down_revision: str | Sequence[str] | None = "d6e7f8a9b0c1"
24+
down_revision: str | Sequence[str] | None = "e8f9a0b1c2d3"
2525
branch_labels: str | Sequence[str] | None = None
2626
depends_on: str | Sequence[str] | None = None
2727

hindsight-api-slim/hindsight_api/alembic/versions/d6e7f8a9b0c1_case_insensitive_entities_trgm_index.py renamed to hindsight-api-slim/hindsight_api/alembic/versions/e8f9a0b1c2d3_case_insensitive_entities_trgm_index.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
"Alice" and "alice" to have different trigram sets. This recreates it on
55
LOWER(canonical_name) so the % operator matches case-insensitively.
66
7-
Revision ID: d6e7f8a9b0c1
8-
Revises: c5d6e7f8a9b0
7+
Revision ID: e8f9a0b1c2d3
8+
Revises: d6e7f8a9b0c1
99
Create Date: 2026-03-31
1010
"""
1111

1212
from collections.abc import Sequence
1313

1414
from alembic import context, op
1515

16-
revision: str = "d6e7f8a9b0c1"
17-
down_revision: str | Sequence[str] | None = "c5d6e7f8a9b0"
16+
revision: str = "e8f9a0b1c2d3"
17+
down_revision: str | Sequence[str] | None = "d6e7f8a9b0c1"
1818
branch_labels: str | Sequence[str] | None = None
1919
depends_on: str | Sequence[str] | None = None
2020

@@ -27,7 +27,7 @@ def _get_schema_prefix() -> str:
2727
def upgrade() -> None:
2828
schema = _get_schema_prefix()
2929
# Drop the old case-sensitive trigram index
30-
op.execute(f"DROP INDEX IF EXISTS entities_canonical_name_trgm_idx")
30+
op.execute(f"DROP INDEX IF EXISTS {schema}entities_canonical_name_trgm_idx")
3131
# Create case-insensitive trigram index on LOWER(canonical_name)
3232
op.execute(
3333
f"CREATE INDEX IF NOT EXISTS entities_canonical_name_lower_trgm_idx "
@@ -36,8 +36,8 @@ def upgrade() -> None:
3636

3737

3838
def downgrade() -> None:
39-
op.execute(f"DROP INDEX IF EXISTS entities_canonical_name_lower_trgm_idx")
4039
schema = _get_schema_prefix()
40+
op.execute(f"DROP INDEX IF EXISTS {schema}entities_canonical_name_lower_trgm_idx")
4141
# Restore original case-sensitive index
4242
op.execute(
4343
f"CREATE INDEX IF NOT EXISTS entities_canonical_name_trgm_idx "

hindsight-api-slim/hindsight_api/engine/memory_engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4338,15 +4338,16 @@ async def get_graph_data(
43384338
link for link in links if link["from_unit_id"] in unit_id_set and link["to_unit_id"] in unit_id_set
43394339
]
43404340

4341-
# Get entity information — only for visible units
4342-
if unit_ids:
4341+
# Get entity information — for visible units AND their source memories
4342+
# (observations inherit entities from source memories)
4343+
if all_relevant_ids:
43434344
unit_entities = await conn.fetch(f"""
43444345
SELECT ue.unit_id, e.canonical_name
43454346
FROM {fq_table("unit_entities")} ue
43464347
JOIN {fq_table("entities")} e ON ue.entity_id = e.id
43474348
WHERE ue.unit_id = ANY($1::uuid[])
43484349
ORDER BY ue.unit_id
4349-
""", unit_ids)
4350+
""", all_relevant_ids)
43504351
else:
43514352
unit_entities = []
43524353

hindsight-api-slim/hindsight_api/engine/retain/link_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def _bulk_insert_links(
5858
chunk_size: int = 5000,
5959
skip_exists_check: bool = False,
6060
) -> None:
61-
"""Insert links into memory_links using sorted bulk INSERT FROM unnest().
61+
"""Bulk-insert links using sorted INSERT FROM unnest().
6262
6363
Sorting by (from_unit_id, to_unit_id) ensures all concurrent transactions
6464
acquire index locks in the same order, eliminating circular-wait deadlocks.
@@ -944,7 +944,7 @@ async def create_semantic_links_batch(
944944

945945
async def insert_entity_links_batch(conn, links: list[EntityLink], bank_id: str, chunk_size: int = 5000):
946946
"""
947-
Insert entity links into memory_links via sorted bulk INSERT FROM unnest().
947+
Bulk-insert entity links via sorted INSERT FROM unnest().
948948
949949
Args:
950950
conn: Database connection

hindsight-api-slim/hindsight_api/engine/retain/orchestrator.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,53 @@ async def retain_batch(
434434
# Convert dicts to RetainContent objects
435435
contents = _build_contents(contents_dicts, document_tags)
436436

437+
# When contents have multiple distinct per-content document_ids and no
438+
# batch-level document_id, group by doc_id and process each group
439+
# independently so each document is tracked separately.
440+
if not document_id:
441+
per_content_doc_ids = [item.get("document_id") for item in contents_dicts]
442+
unique_doc_ids = {d for d in per_content_doc_ids if d}
443+
if len(unique_doc_ids) > 1:
444+
# Group contents by document_id, preserving original order
445+
groups: dict[str, tuple[list[RetainContentDict], list[RetainContent]]] = {}
446+
original_indices: dict[str, list[int]] = {}
447+
for idx, (cd, c) in enumerate(zip(contents_dicts, contents)):
448+
doc_key = cd.get("document_id") or str(uuid.uuid4())
449+
if doc_key not in groups:
450+
groups[doc_key] = ([], [])
451+
original_indices[doc_key] = []
452+
groups[doc_key][0].append(cd)
453+
groups[doc_key][1].append(c)
454+
original_indices[doc_key].append(idx)
455+
456+
# Process each group and merge results back in original order
457+
result_unit_ids: list[list[str]] = [[] for _ in contents_dicts]
458+
total_usage = TokenUsage()
459+
for doc_key, (group_dicts, group_contents) in groups.items():
460+
group_ids, group_usage = await retain_batch(
461+
pool=pool,
462+
embeddings_model=embeddings_model,
463+
llm_config=llm_config,
464+
entity_resolver=entity_resolver,
465+
format_date_fn=format_date_fn,
466+
bank_id=bank_id,
467+
contents_dicts=group_dicts,
468+
config=config,
469+
document_id=doc_key,
470+
is_first_batch=is_first_batch,
471+
fact_type_override=fact_type_override,
472+
document_tags=document_tags,
473+
operation_id=operation_id,
474+
schema=schema,
475+
outbox_callback=outbox_callback,
476+
db_semaphore=db_semaphore,
477+
)
478+
for group_idx, orig_idx in enumerate(original_indices[doc_key]):
479+
if group_idx < len(group_ids):
480+
result_unit_ids[orig_idx] = group_ids[group_idx]
481+
total_usage = total_usage + group_usage
482+
return result_unit_ids, total_usage
483+
437484
# Resolve effective document_id early so both delta and streaming paths
438485
# can find existing chunks from a prior attempt. On retry, the generated
439486
# document_id is recovered from operation result_metadata.
@@ -508,10 +555,12 @@ async def retain_batch(
508555
# retain code paths.
509556
chunk_batch_size = getattr(config, "retain_chunk_batch_size", 100)
510557
chunk_size = getattr(config, "retain_chunk_size", 3000)
511-
all_pre_chunks = []
512-
for content in contents:
558+
all_pre_chunks: list[str] = []
559+
chunk_to_content: list[int] = [] # maps chunk index -> index into contents
560+
for content_idx, content in enumerate(contents):
513561
content_chunks = fact_extraction.chunk_text(content.content, chunk_size)
514562
all_pre_chunks.extend(content_chunks)
563+
chunk_to_content.extend([content_idx] * len(content_chunks))
515564

516565
total_pre_chunks = len(all_pre_chunks)
517566
num_batches = (total_pre_chunks + chunk_batch_size - 1) // chunk_batch_size if total_pre_chunks > 0 else 1
@@ -538,6 +587,7 @@ async def retain_batch(
538587
log_buffer=log_buffer,
539588
start_time=start_time,
540589
all_pre_chunks=all_pre_chunks,
590+
chunk_to_content=chunk_to_content,
541591
chunk_batch_size=chunk_batch_size,
542592
operation_id=operation_id,
543593
schema=schema,
@@ -676,6 +726,7 @@ async def _streaming_retain_batch(
676726
log_buffer: list[str],
677727
start_time: float,
678728
all_pre_chunks: list[str],
729+
chunk_to_content: list[int],
679730
chunk_batch_size: int,
680731
operation_id: str | None = None,
681732
schema: str | None = None,
@@ -704,8 +755,8 @@ async def _streaming_retain_batch(
704755
# operation result_metadata on retry).
705756
effective_doc_id = document_id
706757

707-
# Use the first content item as the template for metadata (context, event_date, etc.)
708-
template_content = contents[0] if contents else RetainContent(content="")
758+
# Default template for metadata (context, event_date, etc.) when content list is empty.
759+
_default_content = RetainContent(content="")
709760

710761
# Load existing chunk hashes BEFORE document tracking to detect recovery.
711762
# If chunks exist AND the document content hash matches, this is a retry of
@@ -774,14 +825,15 @@ async def _streaming_retain_batch(
774825
# it pushes the enriched result into the queue for the DB consumer.
775826
async def _llm_producer() -> None:
776827
async def _extract_one(global_idx: int, chunk_text: str) -> None:
828+
source = contents[chunk_to_content[global_idx]] if contents else _default_content
777829
content = RetainContent(
778830
content=chunk_text,
779-
context=template_content.context,
780-
event_date=template_content.event_date,
781-
metadata=template_content.metadata,
782-
entities=template_content.entities,
783-
tags=template_content.tags,
784-
observation_scopes=template_content.observation_scopes,
831+
context=source.context,
832+
event_date=source.event_date,
833+
metadata=source.metadata,
834+
entities=source.entities,
835+
tags=source.tags,
836+
observation_scopes=source.observation_scopes,
785837
)
786838
extracted, processed, chunk_meta, usage = await _extract_and_embed(
787839
[content],

hindsight-api-slim/tests/test_hierarchical_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ async def test_hierarchical_fields_categorization():
9595
assert "reflect_source_facts_max_tokens" in configurable
9696
assert "llm_gemini_safety_settings" in configurable
9797
assert "mcp_enabled_tools" in configurable
98+
assert "retain_chunk_batch_size" in configurable
9899

99100
# Verify count is correct
100-
assert len(configurable) == 21
101+
assert len(configurable) == 22
101102

102103
# Verify credential fields (NEVER exposed)
103104
assert "llm_api_key" in credentials

0 commit comments

Comments (0)