Skip to content

Commit 3d00216

Browse files
authored
Support DataLoader in DynamicVamana.build (#186)
Includes: * Unit test * Python binding * Python test Also includes: * Unit test for Vamana.build * Test utilities for distance functors * Test fixture for discovering distances in reference results
1 parent b76d098 commit 3d00216

File tree

9 files changed

+432
-12
lines changed

9 files changed

+432
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ wheelhouse/
1111
tags
1212
compile_commands.json
1313
.python-version
14+
.vscode
1415

1516
# Python related files
1617
__pycache__/

bindings/python/src/dynamic_vamana.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
#include <pybind11/pybind11.h>
3232
#include <pybind11/stl.h>
3333

34+
// fmt
35+
#include <fmt/format.h>
36+
3437
// stl
3538
#include <span>
3639

@@ -88,6 +91,75 @@ void add_build_specialization(py::class_<svs::DynamicVamana>& index) {
8891
);
8992
}
9093

94+
/////
95+
///// Build from file (data loader)
96+
/////
97+
98+
template <typename Q, typename T, typename Dist, size_t N>
99+
svs::DynamicVamana dynamic_vamana_build_uncompressed(
100+
const svs::index::vamana::VamanaBuildParameters& parameters,
101+
svs::VectorDataLoader<T, N, RebindAllocator<T>> data_loader,
102+
std::span<const size_t> ids,
103+
svs::DistanceType distance_type,
104+
size_t num_threads
105+
) {
106+
return svs::DynamicVamana::build<Q>(
107+
parameters,
108+
std::move(data_loader),
109+
ids,
110+
distance_type,
111+
num_threads
112+
);
113+
}
114+
115+
using DynamicVamanaBuildFromFileDispatcher = svs::lib::Dispatcher<
116+
svs::DynamicVamana,
117+
const svs::index::vamana::VamanaBuildParameters&,
118+
UnspecializedVectorDataLoader,
119+
std::span<const size_t>,
120+
svs::DistanceType,
121+
size_t>;
122+
123+
DynamicVamanaBuildFromFileDispatcher dynamic_vamana_build_from_file_dispatcher() {
124+
auto dispatcher = DynamicVamanaBuildFromFileDispatcher{};
125+
// Register uncompressed specializations (Dynamic dimensionality only, similar to tests)
126+
for_standard_specializations([&]<typename Q, typename T, typename D, size_t N>() {
127+
// Only register when N is Dynamic (compile-time tag) - the pattern in static code
128+
// registers all; here we directly register.
129+
auto method = &dynamic_vamana_build_uncompressed<Q, T, D, N>;
130+
dispatcher.register_target(svs::lib::dispatcher_build_docs, method);
131+
});
132+
return dispatcher;
133+
}
134+
135+
/// @brief Python-facing entry point: build a ``svs::DynamicVamana`` from a data loader.
///
/// Wraps the ID array in a non-owning span and forwards everything to a freshly
/// constructed build-from-file dispatcher, which selects the registered target.
///
/// @param parameters Build parameters forwarded to the dispatcher.
/// @param data_loader Type-erased loader; moved into the dispatcher call.
/// @param py_ids Contiguous array of IDs. Only viewed, never copied, so it must stay
///     alive for the duration of this call.
/// @param distance_type Distance selector forwarded to the dispatcher.
/// @param num_threads Thread count forwarded to the dispatcher.
///
/// @returns The index returned by the dispatched build target.
svs::DynamicVamana dynamic_vamana_build_from_file(
    const svs::index::vamana::VamanaBuildParameters& parameters,
    UnspecializedVectorDataLoader data_loader,
    const py_contiguous_array_t<size_t>& py_ids,
    svs::DistanceType distance_type,
    size_t num_threads
) {
    // Non-owning view over the caller's ID buffer.
    std::span<const size_t> ids(py_ids.data(), py_ids.size());
    return dynamic_vamana_build_from_file_dispatcher().invoke(
        parameters,
        std::move(data_loader),
        ids,
        distance_type,
        num_threads
    );
}
147+
148+
constexpr std::string_view DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO = R"(
149+
Construct a DynamicVamana index using a data loader, returning the index.
150+
151+
Args:
152+
parameters: Build parameters controlling graph construction.
153+
data_loader: Data loader (e.g., an VectorDataLoader instance).
154+
ids: Vector of ids to assign to each row in the dataset; must match dataset length and contain unique values.
155+
distance_type: The similarity function to use for this index.
156+
num_threads: Number of threads to use for index construction. Default: 1.
157+
158+
Specializations compiled into the binary are listed below.
159+
160+
{} # (Method listing auto-generated)
161+
)";
162+
91163
template <typename ElementType>
92164
void add_points(
93165
svs::DynamicVamana& index,
@@ -301,6 +373,31 @@ void wrap(py::module& m) {
301373
// Index building.
302374
add_build_specialization<float>(vamana);
303375

376+
// Build from file / data loader (dynamic docstring)
377+
{
378+
auto dispatcher = dynamic_vamana_build_from_file_dispatcher();
379+
std::string dynamic;
380+
for (size_t i = 0; i < dispatcher.size(); ++i) {
381+
fmt::format_to(
382+
std::back_inserter(dynamic),
383+
R"(Method {}:\n - data_loader: {}\n - distance: {}\n)",
384+
i,
385+
dispatcher.description(i, 1),
386+
dispatcher.description(i, 3)
387+
);
388+
}
389+
vamana.def_static(
390+
"build",
391+
&dynamic_vamana_build_from_file,
392+
py::arg("parameters"),
393+
py::arg("data_loader"),
394+
py::arg("ids"),
395+
py::arg("distance_type"),
396+
py::arg("num_threads") = 1,
397+
fmt::format(DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO, dynamic).c_str()
398+
);
399+
}
400+
304401
// Index modification.
305402
add_points_specialization<float>(vamana);
306403

bindings/python/tests/test_dynamic_vamana.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
# unit under test
1616
import svs
17+
import numpy as np
1718

1819
# stdlib
1920
import unittest
2021
import os
2122
from tempfile import TemporaryDirectory
2223

2324
# helpers
25+
from .common import test_data_svs, test_data_dims, test_number_of_vectors, test_queries, test_groundtruth_l2
2426
from .dynamic import ReferenceDataset
2527

2628
class DynamicVamanaTester(unittest.TestCase):
@@ -162,3 +164,41 @@ def test_loop(self):
162164
)
163165
consolidate_count = 0
164166

167+
def test_build_from_loader(self):
    """Build a DynamicVamana index from a VectorDataLoader plus explicit IDs,
    then check basic index properties and that search recall is plausible."""

    data_loader = svs.VectorDataLoader(
        test_data_svs, svs.DataType.float32, dims = test_data_dims
    )

    # One external ID per row: 0, 1, ..., test_number_of_vectors - 1.
    external_ids = np.arange(test_number_of_vectors, dtype = np.uint64)

    build_parameters = svs.VamanaBuildParameters(
        graph_max_degree = 64, window_size = 128, alpha = 1.2
    )

    index = svs.DynamicVamana.build(
        build_parameters,
        data_loader,
        external_ids,
        svs.DistanceType.L2,
        num_threads = 2,
    )

    # Size and dimensionality must match the source dataset.
    self.assertEqual(index.size, test_number_of_vectors)
    self.assertEqual(index.dimensions, test_data_dims)

    # The first and last assigned IDs must be present.
    for expected_id in (0, test_number_of_vectors - 1):
        self.assertTrue(index.has_id(expected_id))

    # Search and compare against the L2 groundtruth.
    num_neighbors = 10
    queries = svs.read_vecs(test_queries)
    groundtruth = svs.read_vecs(test_groundtruth_l2)
    index.search_window_size = 20
    neighbors, _distances = index.search(queries, num_neighbors)
    self.assertEqual(neighbors.shape[1], num_neighbors)

    recall = svs.k_recall_at(groundtruth, neighbors, num_neighbors, num_neighbors)
    # Recall in plausible range
    self.assertTrue(0.5 < recall <= 1.0)
204+

include/svs/orchestrators/dynamic_vamana.h

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include "svs/orchestrators/manager.h"
2222
#include "svs/orchestrators/vamana.h"
2323

24+
// stdlib
25+
#include <type_traits>
26+
2427
namespace svs {
2528

2629
///
@@ -258,25 +261,55 @@ class DynamicVamana : public manager::IndexManager<DynamicVamanaInterface> {
258261
}
259262

260263
// Building
264+
///
265+
/// @brief Construct a DynamicVamana index from a data loader or dataset.
266+
///
267+
/// @tparam QueryTypes The set of query element types supported by the resulting index.
268+
/// @tparam DataLoader A data loader or dataset type.
269+
/// @tparam Distance Distance functor or ``svs::DistanceType`` enum.
270+
/// @tparam ThreadPoolProto Thread pool type (or ``size_t`` for a number of threads).
271+
///
272+
/// @param parameters Build parameters controlling graph construction.
273+
/// @param data_loader Loader (or dataset) from which to obtain the data.
274+
/// @param ids External IDs to assign to each row; must be unique and have length ``data.size()``.
275+
/// @param distance Distance functor or enum.
276+
/// @param threadpool_proto Thread pool or number of threads to use.
277+
///
261278
template <
262279
manager::QueryTypeDefinition QueryTypes,
263-
data::ImmutableMemoryDataset Data,
280+
typename DataLoader,
264281
typename Distance,
265282
typename ThreadPoolProto>
266283
static DynamicVamana build(
267284
const index::vamana::VamanaBuildParameters& parameters,
268-
Data data,
285+
DataLoader&& data_loader,
269286
std::span<const size_t> ids,
270287
Distance distance,
271288
ThreadPoolProto threadpool_proto
272289
) {
273-
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
274-
parameters,
275-
std::move(data),
276-
ids,
277-
std::move(distance),
278-
threads::as_threadpool(std::move(threadpool_proto))
279-
);
290+
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
291+
auto data = svs::detail::dispatch_load(std::forward<DataLoader>(data_loader), threadpool);
292+
// If given a DistanceType enum, dispatch to a concrete distance functor first.
293+
if constexpr (std::is_same_v<std::decay_t<Distance>, DistanceType>) {
294+
auto dispatcher = DistanceDispatcher(distance);
295+
return dispatcher([&](auto distance_function) {
296+
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
297+
parameters,
298+
std::move(data),
299+
ids,
300+
std::move(distance_function),
301+
std::move(threadpool)
302+
);
303+
});
304+
} else {
305+
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
306+
parameters,
307+
std::move(data),
308+
ids,
309+
std::move(distance),
310+
std::move(threadpool)
311+
);
312+
}
280313
}
281314

282315
// Assembly

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ SET(INTEGRATION_TESTS
159159
${TEST_DIR}/svs/index/vamana/dynamic_index_2.cpp
160160
# Higher level constructs
161161
${TEST_DIR}/svs/orchestrators/vamana.cpp
162+
${TEST_DIR}/svs/orchestrators/dynamic_vamana.cpp
162163
# Integration Tests
163164
${TEST_DIR}/integration/exhaustive.cpp
164165
${TEST_DIR}/integration/vamana/index_search.cpp
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// Orchestrator under test
18+
#include "svs/orchestrators/dynamic_vamana.h"
19+
20+
// Core helpers
21+
#include "svs/core/recall.h"
22+
#include "svs/core/data/simple.h"
23+
24+
// Distance dispatcher
25+
#include "svs/core/distance.h"
26+
27+
// Test dataset utilities
28+
#include "tests/utils/test_dataset.h"
29+
#include "tests/utils/vamana_reference.h"
30+
#include "tests/utils/utils.h"
31+
32+
// Catch2
33+
#include "catch2/catch_test_macros.hpp"
34+
#include "catch2/catch_approx.hpp"
35+
36+
// STL
37+
#include <vector>
38+
#include <numeric>
39+
40+
namespace {
41+
42+
template <typename DataLoaderT, typename DistanceT>
43+
void test_build(
44+
DataLoaderT&& data_loader,
45+
DistanceT distance = DistanceT()
46+
) {
47+
auto expected_result = test_dataset::vamana::expected_build_results(
48+
distance, svsbenchmark::Uncompressed(svs::DataType::float32)
49+
);
50+
auto build_params = expected_result.build_parameters_.value();
51+
auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
52+
auto groundtruth = test_dataset::load_groundtruth(distance);
53+
54+
// Prepare IDs (0 .. N-1)
55+
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
56+
const size_t n = data.size();
57+
std::vector<size_t> ids(n);
58+
std::iota(ids.begin(), ids.end(), 0);
59+
60+
size_t num_threads = 2;
61+
svs::DynamicVamana index = svs::DynamicVamana::build<float>(
62+
build_params,
63+
std::forward<DataLoaderT>(data_loader),
64+
ids,
65+
distance,
66+
num_threads
67+
);
68+
69+
// Basic invariants
70+
CATCH_REQUIRE(index.get_alpha() == Catch::Approx(build_params.alpha));
71+
CATCH_REQUIRE(index.get_construction_window_size() == build_params.window_size);
72+
CATCH_REQUIRE(index.get_prune_to() == build_params.prune_to);
73+
CATCH_REQUIRE(index.get_graph_max_degree() == build_params.graph_max_degree);
74+
CATCH_REQUIRE(index.get_num_threads() == num_threads);
75+
76+
// ID checks (spot sample)
77+
CATCH_REQUIRE(index.has_id(0));
78+
CATCH_REQUIRE(index.has_id(n / 2));
79+
CATCH_REQUIRE(index.has_id(n - 1));
80+
81+
const double epsilon = 0.01; // allow small deviation
82+
for (const auto& expected : expected_result.config_and_recall_) {
83+
auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_);
84+
auto these_groundtruth =
85+
test_dataset::get_test_set(groundtruth, expected.num_queries_);
86+
index.set_search_parameters(expected.search_parameters_);
87+
auto results = index.search(these_queries, expected.num_neighbors_);
88+
double recall = svs::k_recall_at_n(
89+
these_groundtruth, results, expected.num_neighbors_, expected.recall_k_
90+
);
91+
CATCH_REQUIRE(recall > expected.recall_ - epsilon);
92+
CATCH_REQUIRE(recall < expected.recall_ + epsilon);
93+
}
94+
}
95+
96+
} // namespace
97+
98+
// Exercise DynamicVamana::build for every distance that has reference build results,
// once with an in-memory dataset plus a distance functor and once with a
// VectorDataLoader plus the distance enum.
CATCH_TEST_CASE("DynamicVamana Build", "[managers][dynamic_vamana][build]") {
    for (auto distance_enum : test_dataset::vamana::available_build_distances()) {
        const std::string distance_name = std::string(svs::name(distance_enum));

        // Already-loaded SimpleData combined with a concrete distance functor.
        const std::string simple_data_section = std::string("SimpleData ") + distance_name;
        CATCH_SECTION(simple_data_section) {
            svs::DistanceDispatcher dispatcher(distance_enum);
            dispatcher([&](auto distance_functor) {
                test_build(
                    svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
                    distance_functor
                );
            });
        }

        // Lazy VectorDataLoader combined with the runtime distance enum.
        const std::string loader_section = std::string("VectorDataLoader ") + distance_name;
        CATCH_SECTION(loader_section) {
            test_build(
                svs::VectorDataLoader<float>(test_dataset::data_svs_file()), distance_enum
            );
        }
    }
}

0 commit comments

Comments
 (0)