Skip to content

Commit 3aac487

Browse files
committed
fix(retrieval): Apply proof_count boost to link_expansion retrieval
1 parent ed4f994 commit 3aac487

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def _find_semantic_seeds(
5959
rows = await conn.fetch(
6060
f"""
6161
SELECT id, text, context, event_date, occurred_start, occurred_end,
62-
mentioned_at, fact_type, document_id, chunk_id, tags,
62+
mentioned_at, fact_type, document_id, chunk_id, tags, proof_count,
6363
1 - (embedding <=> $1::vector) AS similarity
6464
FROM {fq_table("memory_units")}
6565
WHERE bank_id = $2
@@ -274,7 +274,7 @@ async def _expand_combined(
274274
-- Score = COUNT(DISTINCT shared entities), mapped to [0,1] via tanh.
275275
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
276276
mu.occurred_end, mu.mentioned_at,
277-
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
277+
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
278278
COUNT(DISTINCT ue_seed.entity_id)::float AS score,
279279
'entity'::text AS source
280280
FROM {ue} ue_seed
@@ -298,14 +298,14 @@ async def _expand_combined(
298298
SELECT
299299
id, text, context, event_date, occurred_start,
300300
occurred_end, mentioned_at,
301-
fact_type, document_id, chunk_id, tags,
301+
fact_type, document_id, chunk_id, tags, proof_count,
302302
MAX(weight) AS score,
303303
'semantic'::text AS source
304304
FROM (
305305
SELECT
306306
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
307307
mu.occurred_end, mu.mentioned_at,
308-
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
308+
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
309309
ml.weight
310310
FROM {ml} ml
311311
JOIN {mu} mu ON mu.id = ml.to_unit_id
@@ -317,7 +317,7 @@ async def _expand_combined(
317317
SELECT
318318
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
319319
mu.occurred_end, mu.mentioned_at,
320-
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
320+
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
321321
ml.weight
322322
FROM {ml} ml
323323
JOIN {mu} mu ON mu.id = ml.from_unit_id
@@ -328,7 +328,7 @@ async def _expand_combined(
328328
) sem_raw
329329
GROUP BY id, text, context, event_date, occurred_start,
330330
occurred_end, mentioned_at,
331-
fact_type, document_id, chunk_id, tags
331+
fact_type, document_id, chunk_id, tags, proof_count
332332
ORDER BY score DESC
333333
LIMIT $3
334334
),
@@ -339,7 +339,7 @@ async def _expand_combined(
339339
SELECT DISTINCT ON (mu.id)
340340
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
341341
mu.occurred_end, mu.mentioned_at,
342-
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
342+
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
343343
ml.weight AS score,
344344
'causal'::text AS source
345345
FROM {ml} ml
@@ -429,7 +429,7 @@ async def _expand_observations(
429429
SELECT
430430
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
431431
mu.occurred_end, mu.mentioned_at,
432-
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
432+
mu.fact_type, mu.document_id, mu.chunk_id, mu.tags, mu.proof_count,
433433
(SELECT COUNT(DISTINCT s) FROM unnest(mu.source_memory_ids) s WHERE s = ANY(ca.source_ids))::float AS score
434434
FROM {fq_table("memory_units")} mu, connected_array ca
435435
WHERE mu.fact_type = 'observation'
@@ -453,35 +453,35 @@ async def _expand_observations(
453453
SELECT
454454
id, text, context, event_date, occurred_start,
455455
occurred_end, mentioned_at,
456-
fact_type, document_id, chunk_id, tags,
456+
fact_type, document_id, chunk_id, tags, proof_count,
457457
MAX(weight) AS score,
458458
'semantic'::text AS source
459459
FROM (
460460
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
461461
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
462-
mu.chunk_id, mu.tags, ml.weight
462+
mu.chunk_id, mu.tags, mu.proof_count, ml.weight
463463
FROM {ml} ml JOIN {mu} mu ON mu.id = ml.to_unit_id
464464
WHERE ml.from_unit_id = ANY($1::uuid[])
465465
AND ml.link_type = 'semantic' AND mu.fact_type = 'observation'
466466
AND mu.id != ALL($1::uuid[])
467467
UNION ALL
468468
SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
469469
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
470-
mu.chunk_id, mu.tags, ml.weight
470+
mu.chunk_id, mu.tags, mu.proof_count, ml.weight
471471
FROM {ml} ml JOIN {mu} mu ON mu.id = ml.from_unit_id
472472
WHERE ml.to_unit_id = ANY($1::uuid[])
473473
AND ml.link_type = 'semantic' AND mu.fact_type = 'observation'
474474
AND mu.id != ALL($1::uuid[])
475475
) sem_raw
476476
GROUP BY id, text, context, event_date, occurred_start, occurred_end,
477-
mentioned_at, fact_type, document_id, chunk_id, tags
477+
mentioned_at, fact_type, document_id, chunk_id, tags, proof_count
478478
ORDER BY score DESC LIMIT $2
479479
),
480480
causal_expanded AS (
481481
SELECT DISTINCT ON (mu.id)
482482
mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start,
483483
mu.occurred_end, mu.mentioned_at, mu.fact_type, mu.document_id,
484-
mu.chunk_id, mu.tags, ml.weight AS score, 'causal'::text AS source
484+
mu.chunk_id, mu.tags, mu.proof_count, ml.weight AS score, 'causal'::text AS source
485485
FROM {ml} ml JOIN {mu} mu ON ml.to_unit_id = mu.id
486486
WHERE ml.from_unit_id = ANY($1::uuid[])
487487
AND ml.link_type IN ('causes', 'caused_by', 'enables', 'prevents')

0 commit comments

Comments
 (0)