Skip to content

Commit 1b4ef68

Browse files
authored
feat(search): support vector index controls in hybrid search (#190)
1 parent e7ce493 commit 1b4ef68

5 files changed

Lines changed: 205 additions & 10 deletions

File tree

docs/sql.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ Positional arguments:
124124

125125
Named parameters:
126126
- `k` (BIGINT, default `10`): Number of results to return. Must be > 0.
127+
- `nprobs` (BIGINT, optional): Number of IVF partitions to probe when using a vector index. Must be > 0. Only affects IVF-based vector indices.
128+
- `refine_factor` (BIGINT, optional): Refine factor for the vector branch. Must be > 0.
127129
- `prefilter` (BOOLEAN, default `false`): If `true`, filters are applied before top-k selection.
130+
- `use_index` (BOOLEAN, default `true`): If `true`, allow ANN index usage for the vector branch when available. If `false`, the vector branch runs exact KNN.
128131
- `alpha` (FLOAT, default `0.5`): Vector/text mixing weight. Larger values weigh vector similarity more heavily.
129132
- `oversample_factor` (INTEGER, default `4`): Oversample factor for candidate generation. If provided, must be > 0.
130133

rust/ffi/search.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,12 @@ pub unsafe extern "C" fn lance_create_hybrid_stream_ir(
191191
text_column: *const c_char,
192192
text_query: *const c_char,
193193
k: u64,
194+
nprobes: u64,
195+
refine_factor: u64,
194196
filter_ir: *const u8,
195197
filter_ir_len: usize,
196198
prefilter: u8,
199+
use_index: u8,
197200
alpha: f32,
198201
oversample_factor: u32,
199202
) -> *mut c_void {
@@ -205,9 +208,12 @@ pub unsafe extern "C" fn lance_create_hybrid_stream_ir(
205208
text_column,
206209
text_query,
207210
k,
211+
nprobes,
212+
refine_factor,
208213
filter_ir,
209214
filter_ir_len,
210215
prefilter,
216+
use_index,
211217
alpha,
212218
oversample_factor,
213219
) {
@@ -231,9 +237,12 @@ fn create_hybrid_stream_ir_inner(
231237
text_column: *const c_char,
232238
text_query: *const c_char,
233239
k: u64,
240+
nprobes: u64,
241+
refine_factor: u64,
234242
filter_ir: *const u8,
235243
filter_ir_len: usize,
236244
prefilter: u8,
245+
use_index: u8,
237246
alpha: f32,
238247
oversample_factor: u32,
239248
) -> FfiResult<StreamHandle> {
@@ -265,8 +274,11 @@ fn create_hybrid_stream_ir_inner(
265274
text_column,
266275
text_query,
267276
k_usize,
277+
nprobes,
278+
refine_factor,
268279
filter,
269280
prefilter,
281+
use_index,
270282
alpha,
271283
oversample_factor,
272284
)?;
@@ -282,8 +294,11 @@ fn create_hybrid_batch(
282294
text_column: &str,
283295
text_query: &str,
284296
k: usize,
297+
nprobes: u64,
298+
refine_factor: u64,
285299
filter: Option<datafusion_expr::Expr>,
286300
prefilter: u8,
301+
use_index: u8,
287302
alpha: f32,
288303
oversample_factor: u32,
289304
) -> FfiResult<RecordBatch> {
@@ -303,7 +318,17 @@ fn create_hybrid_batch(
303318
format!("hybrid vector nearest: {err}"),
304319
)
305320
})?;
306-
vector_scan.use_index(false);
321+
if nprobes != 0 {
322+
let nprobes_usize = nonzero_u64_to_usize(nprobes, "nprobes")?;
323+
vector_scan.nprobes(nprobes_usize);
324+
}
325+
if refine_factor != 0 {
326+
let refine_factor_u32: u32 = refine_factor.try_into().map_err(|_| {
327+
FfiError::new(ErrorCode::InvalidArgument, "refine_factor must fit in u32")
328+
})?;
329+
vector_scan.refine(refine_factor_u32);
330+
}
331+
vector_scan.use_index(use_index != 0);
307332
vector_scan.with_row_id();
308333
vector_scan.disable_scoring_autoprojection();
309334
vector_scan

src/include/lance_ffi.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,12 @@ void *lance_create_fts_stream_ir(void *dataset, const char *text_column,
273273
uint8_t prefilter);
274274

275275
void *lance_get_hybrid_schema(void *dataset);
276-
void *lance_create_hybrid_stream_ir(void *dataset, const char *vector_column,
277-
const float *query_values, size_t query_len,
278-
const char *text_column,
279-
const char *text_query, uint64_t k,
280-
const uint8_t *filter_ir,
281-
size_t filter_ir_len, uint8_t prefilter,
282-
float alpha, uint32_t oversample_factor);
276+
void *lance_create_hybrid_stream_ir(
277+
void *dataset, const char *vector_column, const float *query_values,
278+
size_t query_len, const char *text_column, const char *text_query,
279+
uint64_t k, uint64_t nprobes, uint64_t refine_factor,
280+
const uint8_t *filter_ir, size_t filter_ir_len, uint8_t prefilter,
281+
uint8_t use_index, float alpha, uint32_t oversample_factor);
283282

284283
// Index DDL / metadata
285284
int32_t lance_dataset_create_index(void *dataset, const char *index_name,

src/lance_search.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,9 @@ struct LanceSearchBindData : public TableFunctionData {
814814
string vector_column;
815815
vector<float> vector_query;
816816
string text_query;
817+
uint64_t nprobes = 0;
818+
uint64_t refine_factor = 0;
819+
bool use_index = true;
817820
float alpha = 0.5F;
818821
uint32_t oversample_factor = 4;
819822

@@ -882,8 +885,9 @@ static bool LanceSearchLoadNextBatch(LanceSearchLocalState &local_state,
882885
bind_data.dataset, bind_data.vector_column.c_str(),
883886
bind_data.vector_query.data(), bind_data.vector_query.size(),
884887
bind_data.text_column.c_str(), bind_data.text_query.c_str(),
885-
bind_data.k, ir, NumericCast<size_t>(ir_len),
886-
bind_data.prefilter ? 1 : 0, bind_data.alpha,
888+
bind_data.k, bind_data.nprobes, bind_data.refine_factor, ir,
889+
NumericCast<size_t>(ir_len), bind_data.prefilter ? 1 : 0,
890+
bind_data.use_index ? 1 : 0, bind_data.alpha,
887891
bind_data.oversample_factor);
888892
};
889893

@@ -1073,13 +1077,51 @@ LanceHybridBind(ClientContext &context, TableFunctionBindInput &input,
10731077
}
10741078
result->k = NumericCast<uint64_t>(k_val);
10751079

1080+
bool has_nprobes = false;
1081+
int64_t nprobes_val = 0;
1082+
auto nprobes_named = input.named_parameters.find("nprobs");
1083+
if (nprobes_named != input.named_parameters.end() &&
1084+
!nprobes_named->second.IsNull()) {
1085+
has_nprobes = true;
1086+
nprobes_val = nprobes_named->second.DefaultCastAs(LogicalType::BIGINT)
1087+
.GetValue<int64_t>();
1088+
}
1089+
if (has_nprobes && nprobes_val <= 0) {
1090+
throw InvalidInputException("lance_hybrid_search requires nprobs > 0");
1091+
}
1092+
result->nprobes = has_nprobes ? NumericCast<uint64_t>(nprobes_val) : 0;
1093+
1094+
bool has_refine_factor = false;
1095+
int64_t refine_factor_val = 0;
1096+
auto refine_factor_named = input.named_parameters.find("refine_factor");
1097+
if (refine_factor_named != input.named_parameters.end() &&
1098+
!refine_factor_named->second.IsNull()) {
1099+
has_refine_factor = true;
1100+
refine_factor_val =
1101+
refine_factor_named->second.DefaultCastAs(LogicalType::BIGINT)
1102+
.GetValue<int64_t>();
1103+
}
1104+
if (has_refine_factor && refine_factor_val <= 0) {
1105+
throw InvalidInputException(
1106+
"lance_hybrid_search requires refine_factor > 0");
1107+
}
1108+
result->refine_factor =
1109+
has_refine_factor ? NumericCast<uint64_t>(refine_factor_val) : 0;
1110+
10761111
auto prefilter_named = input.named_parameters.find("prefilter");
10771112
if (prefilter_named != input.named_parameters.end() &&
10781113
!prefilter_named->second.IsNull()) {
10791114
result->prefilter =
10801115
prefilter_named->second.DefaultCastAs(LogicalType::BOOLEAN)
10811116
.GetValue<bool>();
10821117
}
1118+
auto use_index_named = input.named_parameters.find("use_index");
1119+
if (use_index_named != input.named_parameters.end() &&
1120+
!use_index_named->second.IsNull()) {
1121+
result->use_index =
1122+
use_index_named->second.DefaultCastAs(LogicalType::BOOLEAN)
1123+
.GetValue<bool>();
1124+
}
10831125

10841126
auto alpha_named = input.named_parameters.find("alpha");
10851127
if (alpha_named != input.named_parameters.end() &&
@@ -1266,6 +1308,9 @@ LanceSearchBindToString(const LanceSearchBindData &bind_data) {
12661308
result["Lance Text Column"] = bind_data.text_column;
12671309
result["Lance Vector Query Dim"] = to_string(bind_data.vector_query.size());
12681310
result["Lance Text Query"] = bind_data.text_query;
1311+
result["Lance Nprobes"] = to_string(bind_data.nprobes);
1312+
result["Lance Refine Factor"] = to_string(bind_data.refine_factor);
1313+
result["Lance Use Index"] = bind_data.use_index ? "true" : "false";
12691314
result["Lance Alpha"] = to_string(bind_data.alpha);
12701315
result["Lance Oversample Factor"] = to_string(bind_data.oversample_factor);
12711316
}
@@ -1320,7 +1365,10 @@ static void RegisterLanceFtsSearch(ExtensionLoader &loader) {
13201365
static void RegisterLanceHybridSearch(ExtensionLoader &loader) {
13211366
auto configure = [](TableFunction &fun) {
13221367
fun.named_parameters["k"] = LogicalType::BIGINT;
1368+
fun.named_parameters["nprobs"] = LogicalType::BIGINT;
1369+
fun.named_parameters["refine_factor"] = LogicalType::BIGINT;
13231370
fun.named_parameters["prefilter"] = LogicalType::BOOLEAN;
1371+
fun.named_parameters["use_index"] = LogicalType::BOOLEAN;
13241372
fun.named_parameters["alpha"] = LogicalType::FLOAT;
13251373
fun.named_parameters["oversample_factor"] = LogicalType::INTEGER;
13261374
fun.projection_pushdown = true;

test/sql/search_functions.test

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ SELECT * FROM lance_vector_search('test/data/test_data.lance', 'vec', [1.0]::FLO
3434
----
3535
Invalid Input Error: lance_vector_search requires refine_factor > 0
3636

37+
# Hybrid search rejects non-positive nprobs
38+
statement error
39+
SELECT * FROM lance_hybrid_search(
40+
'test/data/search_test_data.lance',
41+
'vec',
42+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
43+
'text',
44+
'puppy',
45+
nprobs = 0
46+
)
47+
----
48+
Invalid Input Error: lance_hybrid_search requires nprobs > 0
49+
50+
# Hybrid search rejects non-positive refine_factor
51+
statement error
52+
SELECT * FROM lance_hybrid_search(
53+
'test/data/search_test_data.lance',
54+
'vec',
55+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
56+
'text',
57+
'puppy',
58+
refine_factor = 0
59+
)
60+
----
61+
Invalid Input Error: lance_hybrid_search requires refine_factor > 0
62+
3763
# Sanity: dataset is readable
3864
query I
3965
SELECT count(*) FROM 'test/data/search_test_data.lance'
@@ -308,6 +334,100 @@ ORDER BY _hybrid_score DESC
308334
3
309335
2
310336

337+
# Hybrid search accepts vector search controls in flat mode
338+
query I
339+
SELECT id
340+
FROM lance_hybrid_search(
341+
'test/data/search_test_data.lance',
342+
'vec',
343+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
344+
'text',
345+
'puppy',
346+
k = 3,
347+
nprobs = 2,
348+
refine_factor = 2,
349+
use_index = false,
350+
alpha = 0.5,
351+
oversample_factor = 4
352+
)
353+
ORDER BY _hybrid_score DESC
354+
----
355+
1
356+
3
357+
2
358+
359+
statement ok
360+
COPY (SELECT * FROM 'test/data/search_test_data.lance')
361+
TO 'test/.tmp/search_hybrid_indexed.lance' (FORMAT lance, mode 'overwrite');
362+
363+
statement ok
364+
CREATE INDEX vec_idx ON 'test/.tmp/search_hybrid_indexed.lance' (vec)
365+
USING IVF_FLAT WITH (num_partitions=1, metric_type='l2');
366+
367+
statement ok
368+
CREATE INDEX text_idx ON 'test/.tmp/search_hybrid_indexed.lance' (text)
369+
USING INVERTED;
370+
371+
# Hybrid search exposes vector search controls in EXPLAIN when using an index
372+
query II
373+
EXPLAIN (FORMAT JSON)
374+
SELECT id
375+
FROM lance_hybrid_search(
376+
'test/.tmp/search_hybrid_indexed.lance',
377+
'vec',
378+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
379+
'text',
380+
'puppy',
381+
k = 3,
382+
nprobs = 1,
383+
refine_factor = 2,
384+
use_index = true,
385+
alpha = 0.5,
386+
oversample_factor = 4
387+
);
388+
----
389+
physical_plan <REGEX>:[\s\S]*"Lance Search Mode": "hybrid"[\s\S]*"Lance Nprobes": "1"[\s\S]*"Lance Refine Factor": "2"[\s\S]*"Lance Use Index": "true"[\s\S]*
390+
391+
query II
392+
EXPLAIN (FORMAT JSON)
393+
SELECT id
394+
FROM lance_hybrid_search(
395+
'test/.tmp/search_hybrid_indexed.lance',
396+
'vec',
397+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
398+
'text',
399+
'puppy',
400+
k = 3,
401+
nprobs = 1,
402+
refine_factor = 2,
403+
use_index = false,
404+
alpha = 0.5,
405+
oversample_factor = 4
406+
);
407+
----
408+
physical_plan <REGEX>:[\s\S]*"Lance Search Mode": "hybrid"[\s\S]*"Lance Use Index": "false"[\s\S]*
409+
410+
query I
411+
SELECT id
412+
FROM lance_hybrid_search(
413+
'test/.tmp/search_hybrid_indexed.lance',
414+
'vec',
415+
[0.0, 0.0, 0.0, 0.0]::FLOAT[4],
416+
'text',
417+
'puppy',
418+
k = 3,
419+
nprobs = 1,
420+
refine_factor = 2,
421+
use_index = true,
422+
alpha = 0.5,
423+
oversample_factor = 4
424+
)
425+
ORDER BY _hybrid_score DESC
426+
----
427+
1
428+
3
429+
2
430+
311431
# ===================================================================
312432
# Table name resolution: search functions support catalog table names
313433
# (Users can ATTACH a Lance directory and use table names instead of

0 commit comments

Comments
 (0)