-
Notifications
You must be signed in to change notification settings - Fork 4.1k
sve optimization for HNSW::MinimaxHeap::pop_min() #4699
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,12 @@ | |
| #include <type_traits> | ||
| #endif | ||
|
|
||
| #ifdef __ARM_FEATURE_SVE | ||
| #include <arm_sve.h> | ||
| #include <limits> | ||
| #include <type_traits> | ||
| #endif | ||
|
|
||
| namespace faiss { | ||
|
|
||
| /************************************************************** | ||
|
|
@@ -1306,6 +1312,72 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) { | |
| return ret; | ||
| } | ||
|
|
||
| #elif __ARM_FEATURE_SVE | ||
|
|
||
| int HNSW::MinimaxHeap::pop_min(float* vmin_out) { | ||
| assert(k > 0); | ||
| static_assert( | ||
| std::is_same<storage_idx_t, int32_t>::value, | ||
| "This code expects storage_idx_t to be int32_t"); | ||
|
|
||
| const size_t lanes = svcntw(); | ||
| const svbool_t pg_all = svptrue_b32(); | ||
|
|
||
| prefetch_L2(ids.data()); | ||
| prefetch_L2(dis.data()); | ||
|
|
||
| int32_t min_idx = -1; | ||
| float min_dis = std::numeric_limits<float>::infinity(); | ||
|
|
||
| // Initialize vectors with -1 indices and infinity distances | ||
| svint32_t min_idx_vec = svdup_n_s32(-1); | ||
| svfloat32_t min_dis_vec = svdup_n_f32(min_dis); | ||
| svint32_t current_idx_vec = svindex_s32(0, 1); | ||
|
|
||
| size_t i = 0; | ||
| const size_t k_size = static_cast<size_t>(k); | ||
|
|
||
| while (i < k_size) { | ||
| svbool_t pg_iter = svwhilelt_b32_u64(i, k_size); | ||
|
|
||
| const size_t prefetch_iterations = 2; | ||
| size_t prefetch_idx = i + prefetch_iterations * lanes; | ||
| if (prefetch_idx < k_size) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is for avoiding the out-of-bound addresses. The prefetch_idx within the range [i, i + lanes] is safe in the loop, however we are prefetching i + 2 * lanes which might overflow the upper bound of the loop which might waste CPU cycles for prefetch. Let me know if you still think we should remove the check |
||
| prefetch_L2(ids.data() + prefetch_idx); | ||
| prefetch_L2(dis.data() + prefetch_idx); | ||
| } | ||
|
|
||
| svint32_t idx_val_vec = svld1_s32(pg_iter, ids.data() + i); | ||
| svbool_t pg_valid_idx_value = svcmpne_n_s32(pg_iter, idx_val_vec, -1); | ||
| svfloat32_t dis_val_vec = svld1_f32(pg_valid_idx_value, dis.data() + i); | ||
|
|
||
| svbool_t pg_less_equal_mask = svcmple_f32(pg_valid_idx_value, dis_val_vec, min_dis_vec); | ||
|
|
||
| min_idx_vec = svsel_s32(pg_less_equal_mask, current_idx_vec, min_idx_vec); | ||
| min_dis_vec = svsel_f32(pg_less_equal_mask, dis_val_vec, min_dis_vec); | ||
|
|
||
| current_idx_vec = svadd_n_s32_x(pg_all, current_idx_vec, lanes); | ||
| i += lanes; | ||
| } | ||
|
|
||
| svbool_t pg_valid_pre = svcmpne_n_s32(pg_all, min_idx_vec, -1); | ||
| if (!svptest_any(pg_all, pg_valid_pre)) { | ||
| // No valid elements found (all are -1) | ||
| return -1; | ||
| } | ||
|
|
||
| min_dis = svminv_f32(pg_valid_pre, min_dis_vec); | ||
| svbool_t pg_valid_for_idx = svcmpeq_n_f32(pg_valid_pre, min_dis_vec, min_dis); | ||
| min_idx = svmaxv_s32(pg_valid_for_idx, min_idx_vec); | ||
| if (vmin_out) { | ||
| *vmin_out = min_dis; | ||
| } | ||
| int ret = ids[min_idx]; | ||
| ids[min_idx] = -1; | ||
| --nvalid; | ||
| return ret; | ||
| } | ||
|
|
||
| #else | ||
|
|
||
| // baseline non-vectorized version | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why 2? please add a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review! The "2" here is the best performance I got during benchmarking. The idea is to prefetch the data certain steps ahead (here is 2) which timing is not too early and not too late for cache access. I will add some comment for explaining the usage of here.