Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
#include <type_traits>
#endif

#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
#include <limits>
#include <type_traits>
#endif

namespace faiss {

/**************************************************************
Expand Down Expand Up @@ -1306,6 +1312,72 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
return ret;
}

#elif __ARM_FEATURE_SVE

// ARM SVE implementation of MinimaxHeap::pop_min.
//
// Scans the first k slots of ids/dis, looks for the smallest distance among
// slots whose id is not -1, and removes that entry by overwriting its id
// with -1 and decrementing nvalid.
//
// @param vmin_out  if non-null, receives the distance of the removed entry
// @return          the id stored in the removed slot, or -1 when no slot
//                  was selected (nothing is modified in that case)
//
// Ties on the minimum distance resolve to the largest slot index: within a
// lane the "<=" comparison lets later slots replace earlier ones, and the
// final cross-lane step takes the maximum index.
//
// NOTE(review): a NaN distance never satisfies "<=", so such a slot is never
// selected; if every valid slot held NaN this would return -1 -- confirm
// this matches the other pop_min variants.
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
assert(k > 0);
// ids is read with 32-bit signed vector loads (svld1_s32) below.
static_assert(
std::is_same<storage_idx_t, int32_t>::value,
"This code expects storage_idx_t to be int32_t");

// Number of 32-bit lanes in one SVE vector, and an all-lanes-active
// predicate for 32-bit elements.
const size_t lanes = svcntw();
const svbool_t pg_all = svptrue_b32();

// Prefetch hints for the start of both arrays (prefetch_L2 is defined
// elsewhere in the project).
prefetch_L2(ids.data());
prefetch_L2(dis.data());

int32_t min_idx = -1;
float min_dis = std::numeric_limits<float>::infinity();

// Per-lane running minimum: every lane starts with index -1 ("nothing seen
// yet") and distance +infinity. current_idx_vec holds the slot index that
// each lane is looking at (0, 1, 2, ... for the first iteration).
svint32_t min_idx_vec = svdup_n_s32(-1);
svfloat32_t min_dis_vec = svdup_n_f32(min_dis);
svint32_t current_idx_vec = svindex_s32(0, 1);

size_t i = 0;
const size_t k_size = static_cast<size_t>(k);

while (i < k_size) {
// Active only for lanes whose slot index (i + lane) is below k_size, so
// the final, partial vector does not read past the k slots.
svbool_t pg_iter = svwhilelt_b32_u64(i, k_size);

// Prefetch distance, in vector-widths ahead of the current position.
// Per the PR discussion embedded below, 2 was chosen empirically: it gave
// the best benchmark result, being neither too early nor too late.
const size_t prefetch_iterations = 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 2? please add a comment

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! The "2" here is the best performance I got during benchmarking. The idea is to prefetch the data certain steps ahead (here is 2) which timing is not too early and not too late for cache access. I will add some comment for explaining the usage of here.

// Only issue the prefetch when its target slot is still inside the k
// slots being scanned, so no hint is issued for an address past the range.
size_t prefetch_idx = i + prefetch_iterations * lanes;
if (prefetch_idx < k_size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this if really needed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for avoiding the out-of-bound addresses. The prefetch_idx within the range [i, i + lanes] is safe in the loop, however we are prefetching i + 2 * lanes which might overflow the upper bound of the loop which might waste CPU cycles for prefetch. Let me know if you still think we should remove the check

prefetch_L2(ids.data() + prefetch_idx);
prefetch_L2(dis.data() + prefetch_idx);
}

// Load ids for the in-range lanes, keep only lanes whose id is not the -1
// "empty" marker, and load distances for just those lanes.
svint32_t idx_val_vec = svld1_s32(pg_iter, ids.data() + i);
svbool_t pg_valid_idx_value = svcmpne_n_s32(pg_iter, idx_val_vec, -1);
svfloat32_t dis_val_vec = svld1_f32(pg_valid_idx_value, dis.data() + i);

// Lanes whose valid distance is <= that lane's current minimum. Using "<="
// rather than "<" lets a distance equal to the initial +infinity still be
// recorded, and makes the later slot win on ties within a lane.
svbool_t pg_less_equal_mask = svcmple_f32(pg_valid_idx_value, dis_val_vec, min_dis_vec);

// Update index and distance together, only in the lanes that improved.
min_idx_vec = svsel_s32(pg_less_equal_mask, current_idx_vec, min_idx_vec);
min_dis_vec = svsel_f32(pg_less_equal_mask, dis_val_vec, min_dis_vec);

// Advance every lane's slot index, and the scalar cursor, by one vector.
current_idx_vec = svadd_n_s32_x(pg_all, current_idx_vec, lanes);
i += lanes;
}

// Lanes that recorded at least one entry no longer hold the initial -1.
svbool_t pg_valid_pre = svcmpne_n_s32(pg_all, min_idx_vec, -1);
if (!svptest_any(pg_all, pg_valid_pre)) {
// No lane recorded an entry (e.g. every scanned id was -1)
return -1;
}

// Horizontal reduction in two steps: first the smallest distance across the
// lanes that recorded an entry, then the largest slot index among the lanes
// holding exactly that distance.
min_dis = svminv_f32(pg_valid_pre, min_dis_vec);
svbool_t pg_valid_for_idx = svcmpeq_n_f32(pg_valid_pre, min_dis_vec, min_dis);
min_idx = svmaxv_s32(pg_valid_for_idx, min_idx_vec);
if (vmin_out) {
*vmin_out = min_dis;
}
// Remove the entry: remember its id, mark the slot empty, update the count.
int ret = ids[min_idx];
ids[min_idx] = -1;
--nvalid;
return ret;
}

#else

// baseline non-vectorized version
Expand Down