NOTE(review): this patch text was recovered from a copy in which line breaks
and every <...> span had been lost. Header names and template arguments are
reconstructed from how they are used in the added code; the context line
"#include <type_traits>" and the placement of blank lines are presumed and
should be confirmed against faiss/impl/HNSW.cpp. The added block is an edited
version of the recorded change, so the post-image blob id on the index line no
longer describes it.

diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index e5a43f2ece..c26905bd18 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -22,6 +22,12 @@
 #include <type_traits>
 #endif
 
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#include <limits>
+#include <type_traits>
+#endif
+
 namespace faiss {
 
 /**************************************************************
@@ -1306,6 +1312,101 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
     return ret;
 }
 
+#elif defined(__ARM_FEATURE_SVE)
+
+/// SVE implementation of MinimaxHeap::pop_min.
+///
+/// Scans all k slots, skipping those whose id is -1, and removes the entry
+/// with the smallest distance. When several entries share that distance the
+/// one stored at the highest slot index is removed.
+///
+/// @param vmin_out if non-null, receives the distance of the removed entry
+/// @return id of the removed entry, or -1 when no entry was selected (the
+///         heap is then left unmodified)
+int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
+    assert(k > 0);
+    static_assert(
+            std::is_same<storage_idx_t, int32_t>::value,
+            "This code expects storage_idx_t to be int32_t");
+
+    // Number of 32-bit lanes per SVE vector; only known at run time.
+    const size_t lanes = svcntw();
+    const size_t k_size = static_cast<size_t>(k);
+    const svbool_t pg_all = svptrue_b32();
+
+    // Prefetch data 2 vector iterations ahead to hide memory latency.
+    // Value of 2 was chosen based on performance benchmarking results,
+    // which showed optimal performance compared to other values (1, 3, 4).
+    const size_t prefetch_iterations = 2;
+    const size_t prefetch_distance = prefetch_iterations * lanes;
+
+    prefetch_L2(ids.data());
+    prefetch_L2(dis.data());
+
+    // Per-lane running minimum and the slot index it came from. An index of
+    // -1 marks a lane that has not accepted any entry yet.
+    svint32_t min_idx_vec = svdup_n_s32(-1);
+    svfloat32_t min_dis_vec =
+            svdup_n_f32(std::numeric_limits<float>::infinity());
+    // Slot index handled by each lane in the current iteration.
+    svint32_t current_idx_vec = svindex_s32(0, 1);
+
+    for (size_t i = 0; i < k_size; i += lanes) {
+        // Lanes at or beyond k are inactive in the final iteration.
+        const svbool_t pg_iter = svwhilelt_b32_u64(i, k_size);
+
+        const size_t prefetch_idx = i + prefetch_distance;
+        if (prefetch_idx < k_size) {
+            prefetch_L2(ids.data() + prefetch_idx);
+            prefetch_L2(dis.data() + prefetch_idx);
+        }
+
+        const svint32_t idx_val_vec = svld1_s32(pg_iter, ids.data() + i);
+        // Only slots holding an id other than -1 take part in the comparison.
+        const svbool_t pg_valid_idx_value =
+                svcmpne_n_s32(pg_iter, idx_val_vec, -1);
+        const svfloat32_t dis_val_vec =
+                svld1_f32(pg_valid_idx_value, dis.data() + i);
+
+        // "<=" rather than "<": on equal distances the later slot replaces
+        // the earlier one, and a distance of +infinity can still be accepted.
+        // A NaN distance never compares true, so such a slot is skipped.
+        const svbool_t pg_less_equal_mask =
+                svcmple_f32(pg_valid_idx_value, dis_val_vec, min_dis_vec);
+
+        min_idx_vec =
+                svsel_s32(pg_less_equal_mask, current_idx_vec, min_idx_vec);
+        min_dis_vec = svsel_f32(pg_less_equal_mask, dis_val_vec, min_dis_vec);
+
+        current_idx_vec = svadd_n_s32_x(
+                pg_all, current_idx_vec, static_cast<int32_t>(lanes));
+    }
+
+    // Lanes that accepted at least one entry.
+    const svbool_t pg_valid_pre = svcmpne_n_s32(pg_all, min_idx_vec, -1);
+    if (!svptest_any(pg_all, pg_valid_pre)) {
+        // Nothing was selected: no slot had a valid id with a comparable
+        // distance.
+        return -1;
+    }
+
+    // Horizontal reduction: smallest distance over the participating lanes,
+    // then the largest slot index among the lanes holding that distance.
+    const float min_dis = svminv_f32(pg_valid_pre, min_dis_vec);
+    const svbool_t pg_valid_for_idx =
+            svcmpeq_n_f32(pg_valid_pre, min_dis_vec, min_dis);
+    const int32_t min_idx = svmaxv_s32(pg_valid_for_idx, min_idx_vec);
+
+    if (vmin_out) {
+        *vmin_out = min_dis;
+    }
+    // Vacate the slot by resetting its id to -1.
+    const int ret = ids[min_idx];
+    ids[min_idx] = -1;
+    --nvalid;
+    return ret;
+}
+
 #else
 
 // baseline non-vectorized version