Skip to content

Commit 5cdc714

Browse files
authored
fix(recall): reject empty queries with 400 and fix SQL parameter gap (#632)
* fix(recall): reject empty queries with 400 and fix SQL parameter gap causing IndeterminateDatatypeError When query text contains only punctuation/symbols (no word characters after normalization), the BM25 arms are skipped but the old code still placed `limit` at \$3 in the params list. If tags or tag_groups were also set, their params (\$4+) were referenced in the SQL while \$3 was a gap, causing PostgreSQL to raise IndeterminateDatatypeError. Fix the parameter layout so `limit` is only appended to params when tokens are present (i.e. when BM25 arms actually use LIMIT \$3), and shift tags_param_idx from 4 to 3 in the no-tokens path. Also add a field_validator on RecallRequest.query that rejects empty-after- normalization queries at the API layer with a 400 before they reach the DB. * refactor: extract tokenize_query helper and reuse in RecallRequest validator
1 parent 78aa7c5 commit 5cdc714

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

hindsight-api-slim/hindsight_api/api/http.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ class RecallRequest(BaseModel):
169169
"Each group is a leaf {tags, match} or compound {and: [...]}, {or: [...]}, {not: ...}.",
170170
)
171171

172+
@field_validator("query")
173+
@classmethod
174+
def validate_query_not_empty(cls, v: str) -> str:
175+
from ..engine.search.retrieval import tokenize_query
176+
177+
if not tokenize_query(v):
178+
raise ValueError("query must contain at least one word character after normalization")
179+
return v
180+
172181
@model_validator(mode="after")
173182
def validate_tags_exclusive(self) -> "RecallRequest":
174183
if self.tags is not None and self.tag_groups is not None:

hindsight-api-slim/hindsight_api/engine/search/retrieval.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import asyncio
1212
import logging
13+
import re
1314
from dataclasses import dataclass, field
1415
from datetime import UTC, datetime
1516
from typing import Optional
@@ -26,6 +27,15 @@
2627
logger = logging.getLogger(__name__)
2728

2829

30+
def tokenize_query(query_text: str) -> list[str]:
31+
"""Normalize query text and split into BM25 tokens.
32+
33+
Strips punctuation, lowercases, and splits on whitespace.
34+
Returns an empty list when the query contains no word characters.
35+
"""
36+
return re.sub(r"[^\w\s]", " ", query_text.lower()).split()
37+
38+
2939
@dataclass
3040
class ParallelRetrievalResult:
3141
"""Result from parallel retrieval across all methods."""
@@ -129,12 +139,9 @@ async def retrieve_semantic_bm25_combined(
129139
Returns:
130140
Dict mapping fact_type -> (semantic_results, bm25_results)
131141
"""
132-
import re
133-
134142
result_dict: dict[str, tuple[list[RetrievalResult], list[RetrievalResult]]] = {ft: ([], []) for ft in fact_types}
135143

136-
sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower())
137-
tokens = [token for token in sanitized_text.split() if token]
144+
tokens = tokenize_query(query_text)
138145

139146
# Over-fetch for HNSW approximation; semantic results trimmed to limit in Python.
140147
hnsw_fetch = max(limit * 5, 100)
@@ -148,11 +155,15 @@ async def retrieve_semantic_bm25_combined(
148155
# --- Parameter layout ---
149156
# $1 = query_emb_str (semantic arms)
150157
# $2 = bank_id
151-
# $3 = limit (BM25 LIMIT; semantic uses inlined hnsw_fetch literal)
152-
# $4 = bm25_text (only when tokens present)
153-
# $N = tags (N=4 when no tokens, N=5 when tokens present)
154-
# $M+ = tag_groups params (one per leaf, starting after tags param)
155-
tags_param_idx = 5 if tokens else 4
158+
# When tokens present:
159+
# $3 = limit (BM25 LIMIT; semantic uses inlined hnsw_fetch literal)
160+
# $4 = bm25_text
161+
# $5 = tags (if present)
162+
# $6+ = tag_groups params (one per leaf)
163+
# When no tokens ($3 is skipped — not included in params to avoid type inference gap):
164+
# $3 = tags (if present)
165+
# $4+ = tag_groups params (one per leaf)
166+
tags_param_idx = 5 if tokens else 3
156167
tags_clause = build_tags_where_clause_simple(tags, tags_param_idx, match=tags_match)
157168

158169
# tag_groups params start immediately after the tags param slot
@@ -222,9 +233,10 @@ async def retrieve_semantic_bm25_combined(
222233

223234
query = "\nUNION ALL\n".join(arms)
224235

225-
params: list = [query_emb_str, bank_id, limit]
236+
params: list = [query_emb_str, bank_id]
226237
if tokens:
227-
params.append(bm25_text_param)
238+
params.append(limit) # $3: BM25 LIMIT (only referenced when tokens are present)
239+
params.append(bm25_text_param) # $4
228240
if tags:
229241
params.append(tags)
230242
params.extend(groups_params)

0 commit comments

Comments
 (0)