diff --git a/c/lib.cpp b/c/lib.cpp index e3914674..4aa8d40f 100644 --- a/c/lib.cpp +++ b/c/lib.cpp @@ -478,7 +478,7 @@ USEARCH_EXPORT void usearch_exact_search( // metric_punned_t metric(dimensions, metric_kind_to_cpp(metric_kind), scalar_kind_to_cpp(scalar_kind)); executor_default_t executor(threads); - static exact_search_t search; + exact_search_t search; exact_search_results_t result = search( // (byte_t const*)dataset, dataset_count, dataset_stride, // (byte_t const*)queries, queries_count, queries_stride, // diff --git a/golang/lib.go b/golang/lib.go index 8fa39a23..06b09800 100644 --- a/golang/lib.go +++ b/golang/lib.go @@ -689,7 +689,7 @@ func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key, // - numThreads: Number of threads to use (0 for auto-detection) func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCount uint, datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, - maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) { if len(dataset) == 0 || len(queries) == 0 { return nil, nil, errors.New("dataset and queries cannot be empty") @@ -703,9 +703,15 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo if (len(queries) % int(vectorDimensions)) != 0 { return nil, nil, errors.New("queries length must be a multiple of the dimensions") } + if maxResults == 0 { + return nil, nil, errors.New("maxResults must be greater than zero") + } + + keys = make([]Key, queryCount*maxResults) + distances = make([]float32, queryCount*maxResults) + resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes + resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes - keys = make([]Key, maxResults) - distances = make([]float32, maxResults) var errorMessage *C.char 
C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), @@ -718,8 +724,6 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:maxResults] - distances = distances[:maxResults] return keys, distances, nil } @@ -736,17 +740,19 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo // For contiguous data, use vectorDimensions * sizeof(element_type). func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSize uint, queryCount uint, datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, quantization Quantization, - maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) { if dataset == nil || queries == nil { return nil, nil, errors.New("dataset and queries pointers cannot be nil") } - if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 { - return nil, nil, errors.New("dimensions and sizes must be greater than zero") + if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 || maxResults == 0 { + return nil, nil, errors.New("dimensions, query count, max results and sizes must be greater than zero") } - keys = make([]Key, maxResults) - distances = make([]float32, maxResults) + keys = make([]Key, queryCount*maxResults) + distances = make([]float32, queryCount*maxResults) + resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes + resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes var errorMessage *C.char C.usearch_exact_search(dataset, C.size_t(datasetSize), 
C.size_t(datasetStride), queries, C.size_t(queryCount), C.size_t(queryStride), quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads), @@ -757,8 +763,6 @@ func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSi return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:maxResults] - distances = distances[:maxResults] return keys, distances, nil } @@ -850,7 +854,7 @@ func DistanceI8(vec1 []int8, vec2 []int8, vectorDimensions uint, metric Metric) // For contiguous int8 data, use vectorDimensions * 1 byte. func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount uint, datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, - maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) { + maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) { if len(dataset) == 0 || len(queries) == 0 { return nil, nil, errors.New("dataset and queries cannot be empty") @@ -858,9 +862,13 @@ func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount if vectorDimensions == 0 { return nil, nil, errors.New("dimensions must be greater than zero") } - - keys = make([]Key, maxResults) - distances = make([]float32, maxResults) + if maxResults == 0 { + return nil, nil, errors.New("maxResults must be greater than zero") + } + keys = make([]Key, queryCount*maxResults) + distances = make([]float32, queryCount*maxResults) + resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes + resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes var errorMessage *C.char C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride), C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), 
C.size_t(numThreads), @@ -872,8 +880,6 @@ func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount if errorMessage != nil { return nil, nil, errors.New(C.GoString(errorMessage)) } - keys = keys[:maxResults] - distances = distances[:maxResults] return keys, distances, nil } diff --git a/golang/lib_test.go b/golang/lib_test.go index 6409dce6..279f1648 100644 --- a/golang/lib_test.go +++ b/golang/lib_test.go @@ -30,19 +30,19 @@ func createTestIndex(t *testing.T, dimensions uint, quantization Quantization) * } func generateTestVector(dimensions uint) []float32 { - vector := make([]float32, dimensions) - for i := uint(0); i < dimensions; i++ { - vector[i] = float32(i) + 0.1 - } - return vector + vector := make([]float32, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = float32(i) + 0.1 + } + return vector } func generateTestVectorI8(dimensions uint) []int8 { - vector := make([]int8, dimensions) - for i := uint(0); i < dimensions; i++ { - vector[i] = int8((i % 127) + 1) - } - return vector + vector := make([]int8, dimensions) + for i := uint(0); i < dimensions; i++ { + vector[i] = int8((i % 127) + 1) + } + return vector } func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { @@ -57,16 +57,16 @@ func populateIndex(t *testing.T, index *Index, vectorCount int) [][]float32 { t.Fatalf("Failed to get dimensions: %v", err) } - for i := 0; i < vectorCount; i++ { - vector := generateTestVector(dimensions) - vector[0] = float32(i) // Make each vector unique - vectors[i] = vector + for i := 0; i < vectorCount; i++ { + vector := generateTestVector(dimensions) + vector[0] = float32(i) // Make each vector unique + vectors[i] = vector - err = index.Add(Key(i), vector) - if err != nil { - t.Fatalf("Failed to add vector %d: %v", i, err) - } - } + err = index.Add(Key(i), vector) + if err != nil { + t.Fatalf("Failed to add vector %d: %v", i, err) + } + } return vectors } @@ -189,12 +189,12 @@ func TestBasicOperations(t 
*testing.T) { t.Fatalf("Failed to reserve capacity: %v", err) } - // Add a vector - vector := generateTestVector(defaultTestDimensions) - vector[0] = 42.0 - vector[1] = 24.0 + // Add a vector + vector := generateTestVector(defaultTestDimensions) + vector[0] = 42.0 + vector[1] = 24.0 - err := index.Add(100, vector) + err := index.Add(100, vector) if err != nil { t.Fatalf("Failed to add vector: %v", err) } @@ -524,13 +524,13 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vector := generateTestVector(32) - err := index.Add(1, vector) + vector := generateTestVector(32) + err := index.Add(1, vector) if err != nil { t.Fatalf("F32 Add failed: %v", err) } - keys, _, err := index.Search(vector, 1) + keys, _, err := index.Search(vector, 1) if err != nil { t.Fatalf("F32 Search failed: %v", err) } @@ -551,17 +551,17 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vector := make([]float64, 32) - for i := range vector { - vector[i] = float64(i) + 0.5 - } + vector := make([]float64, 32) + for i := range vector { + vector[i] = float64(i) + 0.5 + } - err := index.AddUnsafe(1, unsafe.Pointer(&vector[0])) + err := index.AddUnsafe(1, unsafe.Pointer(&vector[0])) if err != nil { t.Fatalf("F64 AddUnsafe failed: %v", err) } - keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vector[0]), 1) + keys, _, err := index.SearchUnsafe(unsafe.Pointer(&vector[0]), 1) if err != nil { t.Fatalf("F64 SearchUnsafe failed: %v", err) } @@ -582,13 +582,13 @@ func TestQuantizationTypes(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vector := generateTestVectorI8(32) - err := index.AddI8(1, vector) + vector := generateTestVectorI8(32) + err := index.AddI8(1, vector) if err != nil { t.Fatalf("I8 Add failed: %v", err) } - keys, _, err := index.SearchI8(vector, 1) + 
keys, _, err := index.SearchI8(vector, 1) if err != nil { t.Fatalf("I8 Search failed: %v", err) } @@ -614,8 +614,8 @@ func TestUnsafeOperations(t *testing.T) { if err := index.Reserve(1); err != nil { t.Fatalf("Failed to reserve capacity: %v", err) } - vector := generateTestVector(64) - ptr := unsafe.Pointer(&vector[0]) + vector := generateTestVector(64) + ptr := unsafe.Pointer(&vector[0]) // Test AddUnsafe err := index.AddUnsafe(100, ptr) @@ -671,12 +671,12 @@ func TestConcurrentInsertions(t *testing.T) { _ = index.ChangeThreadsAdd(uint(runtime.NumCPU())) for i := 0; i < totalVectors; i++ { - vector := generateTestVector(64) - vector[0] = float32(i) - if err := index.Add(Key(i), vector); err != nil { - t.Fatalf("Insertion failed at %d: %v", i, err) - } - } + vector := generateTestVector(64) + vector[0] = float32(i) + if err := index.Add(Key(i), vector); err != nil { + t.Fatalf("Insertion failed at %d: %v", i, err) + } + } // Verify final count finalSize, err := index.Len() @@ -763,17 +763,22 @@ func TestExactSearch(t *testing.T) { const datasetSize = 100 const querySize = 10 const vectorDims = 32 + const maxResults = 5 dataset := make([]float32, datasetSize*vectorDims) queries := make([]float32, querySize*vectorDims) // Fill with test data - for i := 0; i < len(dataset); i++ { - dataset[i] = float32(i%100) + 0.1 + for i := 0; i < datasetSize; i++ { + for j := 0; j < vectorDims; j++ { + dataset[i*vectorDims+j] = float32(i%100+j) + 0.1 + } } - for i := 0; i < len(queries); i++ { - queries[i] = float32(i%50) + 0.1 + for i := 0; i < querySize; i++ { + for j := 0; j < vectorDims; j++ { + queries[i*vectorDims+j] = float32(j%50) + 0.1 + } } keys, distances, err := ExactSearch( @@ -781,35 +786,47 @@ func TestExactSearch(t *testing.T) { datasetSize, querySize, vectorDims*4, vectorDims*4, // Stride in bytes for float32 vectorDims, Cosine, - 5, 0, // maxResults=5, numThreads=0 (auto) - 8, 4, // resultKeysStride, resultDistancesStride + maxResults, 0, // maxResults=5, 
numThreads=0 (auto) ) if err != nil { t.Fatalf("ExactSearch failed: %v", err) } - if len(keys) != 5 || len(distances) != 5 { - t.Fatalf("Expected 5 results from ExactSearch, got %d keys and %d distances", + if len(keys) != maxResults*querySize || len(distances) != maxResults*querySize { + t.Fatalf("Expected 5*10 results from ExactSearch, got %d keys and %d distances", len(keys), len(distances)) } + + for i := 0; i < querySize; i++ { + for j := 0; j < maxResults; j++ { + if keys[j] != keys[i*maxResults+j] || distances[j] != distances[i*maxResults+j] { + t.Fatalf("Expected same results from ExactSearch for all keys and distances") + } + } + } }) t.Run("I8 exact search", func(t *testing.T) { const datasetSize = 50 const querySize = 5 const vectorDims = 16 + const maxResults = 3 dataset := make([]int8, datasetSize*vectorDims) queries := make([]int8, querySize*vectorDims) // Fill with test data - for i := 0; i < len(dataset); i++ { - dataset[i] = int8((i % 100) + 1) + for i := 0; i < datasetSize; i++ { + for j := 0; j < vectorDims; j++ { + dataset[i*vectorDims+j] = int8(i%100+j) + 1 + } } - for i := 0; i < len(queries); i++ { - queries[i] = int8((i % 50) + 1) + for i := 0; i < querySize; i++ { + for j := 0; j < vectorDims; j++ { + queries[i*vectorDims+j] = int8(j%50) + 1 + } } keys, distances, err := ExactSearchI8( @@ -817,18 +834,83 @@ func TestExactSearch(t *testing.T) { datasetSize, querySize, vectorDims, vectorDims, // Stride in bytes for int8 vectorDims, L2sq, - 3, 0, // maxResults=3, numThreads=0 (auto) - 8, 4, // resultKeysStride, resultDistancesStride + maxResults, 0, // maxResults=3, numThreads=0 (auto) + ) + + if err != nil { + t.Fatalf("ExactSearchI8 failed: %v", err) + } + + if len(keys) != maxResults*querySize || len(distances) != maxResults*querySize { + t.Fatalf("Expected 3*querySize results from ExactSearchI8, got %d keys and %d distances", + len(keys), len(distances)) + } + + for i := 0; i < querySize; i++ { + for j := 0; j < maxResults; j++ { + if 
keys[j] != keys[i*maxResults+j] || distances[j] != distances[i*maxResults+j] { + t.Fatalf("Expected same results from ExactSearch for all keys and distances") + } + } + } + }) + + t.Run("unsafe exact search", func(t *testing.T) { + const datasetSize = 10 + const querySize = 10 + const vectorDims = 3 + const maxResults = 1 + + dataset := []float32{0.57402676, 0.416747, 0.7048512, + 0.031865682, 0.81882423, 0.57315916, + 0.2874403, 0.045098174, 0.95673627, + 0.006364229, 0.71774554, 0.6962764, + 0.33764744, 0.44205195, 0.831014, + 0.3366346, 0.829091, 0.4464138, + 0.11070566, 0.96180826, 0.2503381, + 0.538731, 0.2840365, 0.7931533, + 0.7719648, 0.20657142, 0.6011644, + 0.21957317, 0.94966024, 0.22345713, + } + + queries := []float32{0.57402676, 0.416747, 0.7048512, + 0.031865682, 0.81882423, 0.57315916, + 0.2874403, 0.045098174, 0.95673627, + 0.006364229, 0.71774554, 0.6962764, + 0.33764744, 0.44205195, 0.831014, + 0.3366346, 0.829091, 0.4464138, + 0.11070566, 0.96180826, 0.2503381, + 0.538731, 0.2840365, 0.7931533, + 0.7719648, 0.20657142, 0.6011644, + 0.21957317, 0.94966024, 0.22345713, + } + + keys, distances, err := ExactSearchUnsafe( + unsafe.Pointer(&dataset[0]), unsafe.Pointer(&queries[0]), + datasetSize, querySize, + vectorDims, vectorDims, // Stride in bytes; NOTE(review): float32 data needs vectorDims*4 here, not vectorDims - confirm + vectorDims, L2sq, F32, + maxResults, 0, // maxResults=1, numThreads=0 (auto) ) if err != nil { t.Fatalf("ExactSearchI8 failed: %v", err) } - if len(keys) != 3 || len(distances) != 3 { - t.Fatalf("Expected 3 results from ExactSearchI8, got %d keys and %d distances", + if len(keys) != maxResults*querySize || len(distances) != maxResults*querySize { + t.Fatalf("Expected 3*querySize results from ExactSearchI8, got %d keys and %d distances", len(keys), len(distances)) } + + fmt.Printf("keys %v\n", keys) + fmt.Printf("distances %v\n", distances) + + for i := 0; i < querySize; i++ { + if keys[i] != Key(i) || distances[i] != 0 { + t.Fatalf("Expected same results from ExactSearch for all keys and
distances") + } + } + }) }