Skip to content

Commit

Permalink
Add Brute Force KNN to Python and Rust API's (#59)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #59
  • Loading branch information
benfred authored Apr 4, 2024
1 parent d62e1ce commit d23de1b
Show file tree
Hide file tree
Showing 19 changed files with 715 additions and 111 deletions.
3 changes: 2 additions & 1 deletion cpp/include/cuvs/distance/distance_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -67,4 +68,4 @@ enum DistanceType {

#ifdef __cplusplus
}
#endif
#endif
40 changes: 20 additions & 20 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ extern "C" {
typedef struct {
uintptr_t addr;
DLDataType dtype;
} bruteForceIndex;
} cuvsBruteForceIndex;

typedef bruteForceIndex* cuvsBruteForceIndex_t;
typedef cuvsBruteForceIndex* cuvsBruteForceIndex_t;

/**
* @brief Allocate BRUTEFORCE index
*
* @param[in] index cuvsBruteForceIndex_t to allocate
* @return cuvsError_t
*/
cuvsError_t bruteForceIndexCreate(cuvsBruteForceIndex_t* index);
cuvsError_t cuvsBruteForceIndexCreate(cuvsBruteForceIndex_t* index);

/**
* @brief De-allocate BRUTEFORCE index
*
* @param[in] index cuvsBruteForceIndex_t to de-allocate
*/
cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index);
cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index);
/**
* @}
*/
Expand Down Expand Up @@ -83,13 +83,13 @@ cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index);
*
* // Create BRUTEFORCE index
* cuvsBruteForceIndex_t index;
* cuvsError_t index_create_status = bruteForceIndexCreate(&index);
* cuvsError_t index_create_status = cuvsBruteForceIndexCreate(&index);
*
* // Build the BRUTEFORCE Index
* cuvsError_t build_status = bruteForceBuild(res, &dataset_tensor, L2Expanded, 0.f, index);
* cuvsError_t build_status = cuvsBruteForceBuild(res, &dataset_tensor, L2Expanded, 0.f, index);
*
* // de-allocate `index` and `res`
* cuvsError_t index_destroy_status = bruteForceIndexDestroy(index);
* cuvsError_t index_destroy_status = cuvsBruteForceIndexDestroy(index);
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
* @endcode
*
Expand All @@ -100,11 +100,11 @@ cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index);
* @param[out] index cuvsBruteForceIndex_t Newly built BRUTEFORCE index
* @return cuvsError_t
*/
cuvsError_t bruteForceBuild(cuvsResources_t res,
DLManagedTensor* dataset,
enum DistanceType metric,
float metric_arg,
cuvsBruteForceIndex_t index);
cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
DLManagedTensor* dataset,
enum DistanceType metric,
float metric_arg,
cuvsBruteForceIndex_t index);
/**
* @}
*/
Expand Down Expand Up @@ -136,24 +136,24 @@ cuvsError_t bruteForceBuild(cuvsResources_t res,
* DLManagedTensor queries;
* DLManagedTensor neighbors;
*
* // Search the `index` built using `bruteForceBuild`
* cuvsError_t search_status = bruteForceSearch(res, index, &queries, &neighbors, &distances);
* // Search the `index` built using `cuvsBruteForceBuild`
* cuvsError_t search_status = cuvsBruteForceSearch(res, index, &queries, &neighbors, &distances);
*
* // de-allocate `res`
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] index bruteForceIndex which has been returned by `bruteForceBuild`
* @param[in] index cuvsBruteForceIndex which has been returned by `cuvsBruteForceBuild`
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
*/
cuvsError_t bruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
/**
* @}
*/
Expand Down
56 changes: 21 additions & 35 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resources.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/brute_force.h>
#include <cuvs/neighbors/brute_force.hpp>
Expand All @@ -49,7 +50,7 @@ void* _build(cuvsResources_t res,

template <typename T>
void _search(cuvsResources_t res,
bruteForceIndex index,
cuvsBruteForceIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
Expand All @@ -70,19 +71,14 @@ void _search(cuvsResources_t res,

} // namespace

extern "C" cuvsError_t bruteForceIndexCreate(cuvsBruteForceIndex_t* index)
extern "C" cuvsError_t cuvsBruteForceIndexCreate(cuvsBruteForceIndex_t* index)
{
try {
*index = new bruteForceIndex{};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { *index = new cuvsBruteForceIndex{}; });
}

extern "C" cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_ptr)
extern "C" cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_ptr)
{
try {
return cuvs::core::translate_exceptions([=] {
auto index = *index_c_ptr;

if (index.dtype.code == kDLFloat) {
Expand All @@ -96,19 +92,16 @@ extern "C" cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_ptr)
delete index_ptr;
}
delete index_c_ptr;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t bruteForceBuild(cuvsResources_t res,
DLManagedTensor* dataset_tensor,
enum DistanceType metric,
float metric_arg,
cuvsBruteForceIndex_t index)
extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
DLManagedTensor* dataset_tensor,
enum DistanceType metric,
float metric_arg,
cuvsBruteForceIndex_t index)
{
try {
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
Expand All @@ -120,19 +113,16 @@ extern "C" cuvsError_t bruteForceBuild(cuvsResources_t res,
dataset.dtype.code,
dataset.dtype.bits);
}
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t bruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
{
try {
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
auto neighbors = neighbors_tensor->dl_tensor;
auto distances = distances_tensor->dl_tensor;
Expand All @@ -159,9 +149,5 @@ extern "C" cuvsError_t bruteForceSearch(cuvsResources_t res,
queries.dtype.code,
queries.dtype.bits);
}

return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}
59 changes: 14 additions & 45 deletions cpp/src/neighbors/ivf_flat_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resources.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/ivf_flat.h>
#include <cuvs/neighbors/ivf_flat.hpp>
Expand Down Expand Up @@ -86,17 +87,12 @@ void _search(cuvsResources_t res,

extern "C" cuvsError_t ivfFlatIndexCreate(cuvsIvfFlatIndex_t* index)
{
try {
*index = new ivfFlatIndex{};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { *index = new ivfFlatIndex{}; });
}

extern "C" cuvsError_t ivfFlatIndexDestroy(cuvsIvfFlatIndex_t index_c_ptr)
{
try {
return cuvs::core::translate_exceptions([=] {
auto index = *index_c_ptr;

if (index.dtype.code == kDLFloat) {
Expand All @@ -113,18 +109,15 @@ extern "C" cuvsError_t ivfFlatIndexDestroy(cuvsIvfFlatIndex_t index_c_ptr)
delete index_ptr;
}
delete index_c_ptr;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t ivfFlatBuild(cuvsResources_t res,
cuvsIvfFlatIndexParams_t params,
DLManagedTensor* dataset_tensor,
cuvsIvfFlatIndex_t index)
{
try {
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
Expand All @@ -144,10 +137,7 @@ extern "C" cuvsError_t ivfFlatBuild(cuvsResources_t res,
dataset.dtype.code,
dataset.dtype.bits);
}
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t ivfFlatSearch(cuvsResources_t res,
Expand All @@ -157,7 +147,7 @@ extern "C" cuvsError_t ivfFlatSearch(cuvsResources_t res,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
{
try {
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
auto neighbors = neighbors_tensor->dl_tensor;
auto distances = distances_tensor->dl_tensor;
Expand Down Expand Up @@ -191,16 +181,12 @@ extern "C" cuvsError_t ivfFlatSearch(cuvsResources_t res,
queries.dtype.code,
queries.dtype.bits);
}

return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsIvfFlatIndexParamsCreate(cuvsIvfFlatIndexParams_t* params)
{
try {
return cuvs::core::translate_exceptions([=] {
*params = new ivfFlatIndexParams{.metric = L2Expanded,
.metric_arg = 2.0f,
.add_data_on_build = true,
Expand All @@ -209,38 +195,21 @@ extern "C" cuvsError_t cuvsIvfFlatIndexParamsCreate(cuvsIvfFlatIndexParams_t* pa
.kmeans_trainset_fraction = 0.5,
.adaptive_centers = false,
.conservative_memory_allocation = false};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsIvfFlatSearchParamsCreate(cuvsIvfFlatSearchParams_t* params)
{
try {
*params = new ivfFlatSearchParams{.n_probes = 20};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions(
[=] { *params = new ivfFlatSearchParams{.n_probes = 20}; });
}

extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}
8 changes: 4 additions & 4 deletions cpp/test/neighbors/run_brute_force_c.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ void run_brute_force(int64_t n_rows,

// create index
cuvsBruteForceIndex_t index;
bruteForceIndexCreate(&index);
cuvsBruteForceIndexCreate(&index);

// build index
bruteForceBuild(res, &dataset_tensor, metric, 0.0f, index);
cuvsBruteForceBuild(res, &dataset_tensor, metric, 0.0f, index);

// create queries DLTensor
DLManagedTensor queries_tensor;
Expand Down Expand Up @@ -86,9 +86,9 @@ void run_brute_force(int64_t n_rows,
distances_tensor.dl_tensor.strides = NULL;

// search index
bruteForceSearch(res, index, &queries_tensor, &neighbors_tensor, &distances_tensor);
cuvsBruteForceSearch(res, index, &queries_tensor, &neighbors_tensor, &distances_tensor);

// de-allocate index and res
bruteForceIndexDestroy(index);
cuvsBruteForceIndexDestroy(index);
cuvsResourcesDestroy(res);
}
Loading

0 comments on commit d23de1b

Please sign in to comment.