From f95b278d07e578c538621148bd1abbcbb1d3f12a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mihai=20Capot=C4=83?= Date: Fri, 19 Sep 2025 22:59:09 -0700 Subject: [PATCH] Support DataLoader in DynamicVamana.build Includes: * Unit test * Python binding * Python test Also includes: * Unit test for Vamana.build * Test utilities for distance functors * Test fixture for discovering distances in reference results --- .gitignore | 1 + bindings/python/src/dynamic_vamana.cpp | 97 ++++++++++++++ bindings/python/tests/test_dynamic_vamana.py | 40 ++++++ include/svs/orchestrators/dynamic_vamana.h | 51 ++++++-- tests/CMakeLists.txt | 1 + tests/svs/orchestrators/dynamic_vamana.cpp | 125 +++++++++++++++++++ tests/svs/orchestrators/vamana.cpp | 94 +++++++++++++- tests/utils/test_dataset.h | 7 ++ tests/utils/vamana_reference.h | 28 +++++ 9 files changed, 432 insertions(+), 12 deletions(-) create mode 100644 tests/svs/orchestrators/dynamic_vamana.cpp diff --git a/.gitignore b/.gitignore index efaf738f6..c74567654 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ wheelhouse/ tags compile_commands.json .python-version +.vscode # Python related files __pycache__/ diff --git a/bindings/python/src/dynamic_vamana.cpp b/bindings/python/src/dynamic_vamana.cpp index e75c62500..a4c3b4984 100644 --- a/bindings/python/src/dynamic_vamana.cpp +++ b/bindings/python/src/dynamic_vamana.cpp @@ -31,6 +31,9 @@ #include #include +// fmt +#include + // stl #include @@ -88,6 +91,75 @@ void add_build_specialization(py::class_& index) { ); } +///// +///// Build from file (data loader) +///// + +template +svs::DynamicVamana dynamic_vamana_build_uncompressed( + const svs::index::vamana::VamanaBuildParameters& parameters, + svs::VectorDataLoader> data_loader, + std::span ids, + svs::DistanceType distance_type, + size_t num_threads +) { + return svs::DynamicVamana::build( + parameters, + std::move(data_loader), + ids, + distance_type, + num_threads + ); +} + +using DynamicVamanaBuildFromFileDispatcher = svs::lib::Dispatcher< + svs::DynamicVamana, + const svs::index::vamana::VamanaBuildParameters&, + UnspecializedVectorDataLoader, + std::span, + svs::DistanceType, + size_t>; + +DynamicVamanaBuildFromFileDispatcher dynamic_vamana_build_from_file_dispatcher() { + auto dispatcher = DynamicVamanaBuildFromFileDispatcher{}; + // Register uncompressed specializations (Dynamic dimensionality only, similar to tests) + for_standard_specializations([&]() { + // Only register when N is Dynamic (compile-time tag) - the pattern in static code + // registers all; here we directly register. + auto method = &dynamic_vamana_build_uncompressed; + dispatcher.register_target(svs::lib::dispatcher_build_docs, method); + }); + return dispatcher; +} + +svs::DynamicVamana dynamic_vamana_build_from_file( + const svs::index::vamana::VamanaBuildParameters& parameters, + UnspecializedVectorDataLoader data_loader, + const py_contiguous_array_t& py_ids, + svs::DistanceType distance_type, + size_t num_threads +) { + auto ids = std::span(py_ids.data(), py_ids.size()); + return dynamic_vamana_build_from_file_dispatcher().invoke( + parameters, std::move(data_loader), ids, distance_type, num_threads + ); +} + +constexpr std::string_view DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO = R"( +Construct a DynamicVamana index using a data loader, returning the index. + +Args: + parameters: Build parameters controlling graph construction. + data_loader: Data loader (e.g., an VectorDataLoader instance). + ids: Vector of ids to assign to each row in the dataset; must match dataset length and contain unique values. + distance_type: The similarity function to use for this index. + num_threads: Number of threads to use for index construction. Default: 1. + +Specializations compiled into the binary are listed below. + +{} # (Method listing auto-generated) +)"; + template void add_points( svs::DynamicVamana& index, @@ -301,6 +373,31 @@ void wrap(py::module& m) { // Index building. add_build_specialization(vamana); + // Build from file / data loader (dynamic docstring) + { + auto dispatcher = dynamic_vamana_build_from_file_dispatcher(); + std::string dynamic; + for (size_t i = 0; i < dispatcher.size(); ++i) { + fmt::format_to( + std::back_inserter(dynamic), + R"(Method {}:\n - data_loader: {}\n - distance: {}\n)", + i, + dispatcher.description(i, 1), + dispatcher.description(i, 3) + ); + } + vamana.def_static( + "build", + &dynamic_vamana_build_from_file, + py::arg("parameters"), + py::arg("data_loader"), + py::arg("ids"), + py::arg("distance_type"), + py::arg("num_threads") = 1, + fmt::format(DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO, dynamic).c_str() + ); + } + // Index modification. add_points_specialization(vamana); diff --git a/bindings/python/tests/test_dynamic_vamana.py b/bindings/python/tests/test_dynamic_vamana.py index 1c278d2a4..5580e0452 100644 --- a/bindings/python/tests/test_dynamic_vamana.py +++ b/bindings/python/tests/test_dynamic_vamana.py @@ -14,6 +14,7 @@ # unit under test import svs +import numpy as np # stdlib import unittest @@ -21,6 +22,7 @@ from tempfile import TemporaryDirectory # helpers +from .common import test_data_svs, test_data_dims, test_number_of_vectors, test_queries, test_groundtruth_l2 from .dynamic import ReferenceDataset class DynamicVamanaTester(unittest.TestCase): @@ -162,3 +164,41 @@ def test_loop(self): ) consolidate_count = 0 + def test_build_from_loader(self): + """Test building DynamicVamana using a VectorDataLoader and explicit IDs.""" + + loader = svs.VectorDataLoader(test_data_svs, svs.DataType.float32, dims = test_data_dims) + + # Sequential IDs + ids = np.arange(test_number_of_vectors, dtype = np.uint64) + + params = svs.VamanaBuildParameters( + graph_max_degree = 64, + window_size = 128, + alpha = 1.2, + ) + + index = svs.DynamicVamana.build( + params, + loader, + ids, + svs.DistanceType.L2, + num_threads = 2, + ) + + # Basic invariants + self.assertEqual(index.size, test_number_of_vectors) + self.assertEqual(index.dimensions, test_data_dims) + self.assertTrue(index.has_id(0)) + self.assertTrue(index.has_id(test_number_of_vectors - 1)) + + queries = svs.read_vecs(test_queries) + groundtruth = svs.read_vecs(test_groundtruth_l2) + k = 10 + index.search_window_size = 20 + I, D = index.search(queries, k) + self.assertEqual(I.shape[1], k) + recall = svs.k_recall_at(groundtruth, I, k, k) + # Recall in plausible range + self.assertTrue(0.5 < recall <= 1.0) + diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h index a0d725781..afba65bbe 100644 --- a/include/svs/orchestrators/dynamic_vamana.h +++ b/include/svs/orchestrators/dynamic_vamana.h @@ -21,6 +21,9 @@ #include "svs/orchestrators/manager.h" #include "svs/orchestrators/vamana.h" +// stdlib +#include + namespace svs { /// @@ -258,25 +261,55 @@ class DynamicVamana : public manager::IndexManager { } // Building + /// + /// @brief Construct a DynamicVamana index from a data loader or dataset. + /// + /// @tparam QueryTypes The set of query element types supported by the resulting index. + /// @tparam DataLoader A data loader or dataset type. + /// @tparam Distance Distance functor or ``svs::DistanceType`` enum. + /// @tparam ThreadPoolProto Thread pool type or size_t). + /// + /// @param parameters Build parameters controlling graph construction. + /// @param data_loader Loader (or dataset) from which to obtain the data. + /// @param ids External IDs to assign to each row; must be unique and have length ``data.size()``. + /// @param distance Distance functor or enum. + /// @param threadpool_proto Thread pool or number of threads to use. + /// template < manager::QueryTypeDefinition QueryTypes, - data::ImmutableMemoryDataset Data, + typename DataLoader, typename Distance, typename ThreadPoolProto> static DynamicVamana build( const index::vamana::VamanaBuildParameters& parameters, - Data data, + DataLoader&& data_loader, std::span ids, Distance distance, ThreadPoolProto threadpool_proto ) { - return make_dynamic_vamana>( - parameters, - std::move(data), - ids, - std::move(distance), - threads::as_threadpool(std::move(threadpool_proto)) - ); + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto data = svs::detail::dispatch_load(std::forward(data_loader), threadpool); + // If given a DistanceType enum, dispatch to a concrete distance functor first. + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_dynamic_vamana>( + parameters, + std::move(data), + ids, + std::move(distance_function), + std::move(threadpool) + ); + }); + } else { + return make_dynamic_vamana>( + parameters, + std::move(data), + ids, + std::move(distance), + std::move(threadpool) + ); + } } // Assembly diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4954d18a8..0d07d4f05 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -159,6 +159,7 @@ SET(INTEGRATION_TESTS ${TEST_DIR}/svs/index/vamana/dynamic_index_2.cpp # Higher level constructs ${TEST_DIR}/svs/orchestrators/vamana.cpp + ${TEST_DIR}/svs/orchestrators/dynamic_vamana.cpp # Integration Tests ${TEST_DIR}/integration/exhaustive.cpp ${TEST_DIR}/integration/vamana/index_search.cpp diff --git a/tests/svs/orchestrators/dynamic_vamana.cpp b/tests/svs/orchestrators/dynamic_vamana.cpp new file mode 100644 index 000000000..e25dea80e --- /dev/null +++ b/tests/svs/orchestrators/dynamic_vamana.cpp @@ -0,0 +1,125 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Orchestrator under test +#include "svs/orchestrators/dynamic_vamana.h" + +// Core helpers +#include "svs/core/recall.h" +#include "svs/core/data/simple.h" + +// Distance dispatcher +#include "svs/core/distance.h" + +// Test dataset utilities +#include "tests/utils/test_dataset.h" +#include "tests/utils/vamana_reference.h" +#include "tests/utils/utils.h" + +// Catch2 +#include "catch2/catch_test_macros.hpp" +#include "catch2/catch_approx.hpp" + +// STL +#include +#include + +namespace { + +template +void test_build( + DataLoaderT&& data_loader, + DistanceT distance = DistanceT() +) { + auto expected_result = test_dataset::vamana::expected_build_results( + distance, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + auto groundtruth = test_dataset::load_groundtruth(distance); + + // Prepare IDs (0 .. N-1) + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + const size_t n = data.size(); + std::vector ids(n); + std::iota(ids.begin(), ids.end(), 0); + + size_t num_threads = 2; + svs::DynamicVamana index = svs::DynamicVamana::build( + build_params, + std::forward(data_loader), + ids, + distance, + num_threads + ); + + // Basic invariants + CATCH_REQUIRE(index.get_alpha() == Catch::Approx(build_params.alpha)); + CATCH_REQUIRE(index.get_construction_window_size() == build_params.window_size); + CATCH_REQUIRE(index.get_prune_to() == build_params.prune_to); + CATCH_REQUIRE(index.get_graph_max_degree() == build_params.graph_max_degree); + CATCH_REQUIRE(index.get_num_threads() == num_threads); + + // ID checks (spot sample) + CATCH_REQUIRE(index.has_id(0)); + CATCH_REQUIRE(index.has_id(n / 2)); + CATCH_REQUIRE(index.has_id(n - 1)); + + const double epsilon = 0.01; // allow small deviation + for (const auto& expected : expected_result.config_and_recall_) { + auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_); + auto these_groundtruth = + test_dataset::get_test_set(groundtruth, expected.num_queries_); + index.set_search_parameters(expected.search_parameters_); + auto results = index.search(these_queries, expected.num_neighbors_); + double recall = svs::k_recall_at_n( + these_groundtruth, results, expected.num_neighbors_, expected.recall_k_ + ); + CATCH_REQUIRE(recall > expected.recall_ - epsilon); + CATCH_REQUIRE(recall < expected.recall_ + epsilon); + } +} + +} // namespace + +CATCH_TEST_CASE("DynamicVamana Build", "[managers][dynamic_vamana][build]") { + for (auto distance_enum : test_dataset::vamana::available_build_distances()) { + // SimpleData and distance functor. + { + std::string section_name = std::string("SimpleData ") + std::string(svs::name(distance_enum)); + CATCH_SECTION(section_name) { + svs::DistanceDispatcher dispatcher(distance_enum); + dispatcher([&](auto distance_functor) { + test_build( + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance_functor + ); + }); + } + } + + // VectorDataLoader and distance enum. + { + std::string section_name = std::string("VectorDataLoader ") + std::string(svs::name(distance_enum)); + CATCH_SECTION(section_name) { + test_build( + svs::VectorDataLoader(test_dataset::data_svs_file()), + distance_enum + ); + } + } + } +} diff --git a/tests/svs/orchestrators/vamana.cpp b/tests/svs/orchestrators/vamana.cpp index 4a5b61a7a..cf78d2d2f 100644 --- a/tests/svs/orchestrators/vamana.cpp +++ b/tests/svs/orchestrators/vamana.cpp @@ -14,12 +14,100 @@ * limitations under the License. */ -// SVS +// Orchestrator under test #include "svs/orchestrators/vamana.h" +// Core helpers +#include "svs/core/recall.h" +#include "svs/core/data/simple.h" + +// Distance dispatcher +#include "svs/core/distance.h" + +// Test dataset utilities +#include "tests/utils/test_dataset.h" +#include "tests/utils/vamana_reference.h" +#include "tests/utils/utils.h" + // Catch2 #include "catch2/catch_test_macros.hpp" +#include "catch2/catch_approx.hpp" + +// STL +#include +#include + +namespace { + +template +void test_build( + DataLoaderT&& data_loader, + DistanceT distance = DistanceT() +) { + auto expected_result = test_dataset::vamana::expected_build_results( + distance, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + auto groundtruth = test_dataset::load_groundtruth(distance); + + size_t num_threads = 2; + svs::Vamana index = svs::Vamana::build( + build_params, + std::forward(data_loader), + distance, + num_threads + ); + + // Basic invariants + CATCH_REQUIRE(index.get_alpha() == Catch::Approx(build_params.alpha)); + CATCH_REQUIRE(index.get_construction_window_size() == build_params.window_size); + CATCH_REQUIRE(index.get_prune_to() == build_params.prune_to); + CATCH_REQUIRE(index.get_graph_max_degree() == build_params.graph_max_degree); + CATCH_REQUIRE(index.get_num_threads() == num_threads); + + const double epsilon = 0.01; // allow small deviation + for (const auto& expected : expected_result.config_and_recall_) { + auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_); + auto these_groundtruth = + test_dataset::get_test_set(groundtruth, expected.num_queries_); + index.set_search_parameters(expected.search_parameters_); + auto results = index.search(these_queries, expected.num_neighbors_); + double recall = svs::k_recall_at_n( + these_groundtruth, results, expected.num_neighbors_, expected.recall_k_ + ); + CATCH_REQUIRE(recall > expected.recall_ - epsilon); + CATCH_REQUIRE(recall < expected.recall_ + epsilon); + } +} + +} // namespace + +CATCH_TEST_CASE("Vamana Build", "[managers][vamana][build]") { + for (auto distance_enum : test_dataset::vamana::available_build_distances()) { + // SimpleData and distance functor. + { + std::string section_name = std::string("SimpleData ") + std::string(svs::name(distance_enum)); + CATCH_SECTION(section_name) { + svs::DistanceDispatcher dispatcher(distance_enum); + dispatcher([&](auto distance_functor) { + test_build( + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance_functor + ); + }); + } + } -CATCH_TEST_CASE("Vamana Index", "[managers][vamana]") { - // Todo? + // VectorDataLoader and distance enum. + { + std::string section_name = std::string("VectorDataLoader ") + std::string(svs::name(distance_enum)); + CATCH_SECTION(section_name) { + test_build( + svs::VectorDataLoader(test_dataset::data_svs_file()), + distance_enum + ); + } + } + } } diff --git a/tests/utils/test_dataset.h b/tests/utils/test_dataset.h index aa55f316b..274c6c3ce 100644 --- a/tests/utils/test_dataset.h +++ b/tests/utils/test_dataset.h @@ -76,6 +76,13 @@ svs::graphs::SimpleBlockedGraph graph_blocked(); /// Helper to load the ground-truth for a given file. svs::data::SimpleData load_groundtruth(svs::DistanceType distance); +// Overload for distance functor types (DistanceL2, DistanceIP, DistanceCosineSimilarity, etc.). +// This enables test code that passes a concrete functor instance instead of the enum. +template +inline svs::data::SimpleData load_groundtruth(DistanceFunctor /*distance*/) { + return load_groundtruth(svs::distance_type_v); +} + /// /// @brief Return a reference to the last `queries_in_test_set` entries in `queries`. /// diff --git a/tests/utils/vamana_reference.h b/tests/utils/vamana_reference.h index 21b812d24..78e57bf0b 100644 --- a/tests/utils/vamana_reference.h +++ b/tests/utils/vamana_reference.h @@ -29,6 +29,7 @@ // stl #include +#include #include #include @@ -70,6 +71,13 @@ expected_build_results(svs::DistanceType distance, const T& dataset) { return result; } +template +svsbenchmark::vamana::ExpectedResult +expected_build_results(DistanceFunctor, const T& dataset) { + // Delegate to the DistanceType overload using the deduced distance type + return expected_build_results(svs::distance_type_v, dataset); +} + /// Return the only reference search for the requested parameters. /// Throws ANNException if the number of dataset is not equal to one. template @@ -126,4 +134,24 @@ load_dynamic_test_index(const Distance& distance) { ); } +// Return the set of distances that have reference build expectations for the +// uncompressed float32 dataset. Cached after first computation. +inline const std::set& available_build_distances() { + static const std::set distances = []() { + std::set ds; + const auto dataset = svsbenchmark::Uncompressed(svs::DataType::float32); + const auto& table = parse_expected(); + auto all_results = svs::lib::load>( + svs::lib::node_view_at(table, "vamana_test_build"), std::nullopt + ); + for (const auto& r : all_results) { + if (r.dataset_.match(dataset)) { + ds.insert(r.distance_); + } + } + return ds; + }(); + return distances; +} + } // namespace test_dataset::vamana