From 291b2563abc6f13f3fbc1118a7849cfc729a056d Mon Sep 17 00:00:00 2001
From: matthijs
Date: Thu, 28 Aug 2025 00:59:30 -0700
Subject: [PATCH 1/5] SIMDConfig object, [faiss] Add support for AVX2 and AVX512 (F, BW, DQ, VL, CD) detection

Summary:
* Added detection of the SIMD instruction set for both `AVX2` and the `AVX512F`/`AVX512VL`-related levels.
* Added hardware-specific unit tests (e.g. checks that when the unit tests run on an x86 machine, the relevant SIMD levels are reported and the corresponding instructions actually execute).
* The detection runs the computation explicitly instead of relying on `__builtin_cpu_supports("avx512f")`, which does not check for OS support: [link](https://stackoverflow.com/questions/48677575/does-gccs-builtin-cpu-supports-check-for-os-support)
* Also fixes a bug in the existing `AVX2` detection:
  * Incorrect CPUID bit check: the function used `ebx & (1 << 16)` to check for `AVX2` support. This is incorrect because bit 16 of `ebx` actually indicates `AVX-512F`, not `AVX2`.
  * Correct bit for AVX2: the correct bit is bit 5 of `ebx` when `eax = 7` and `ecx = 0`, per Intel's documentation of the CPUID instruction.
* Fixes another bug in the SIMDConfig constructor: even when the environment variable was set, the code path still went through hardware detection.
* SIMDConfig now takes a parameter in its constructor, enabling dependency injection for better testing.
* Added more unit tests for other hardware.
* Added a SIMDConfig member that tracks all possible supported SIMD levels.

Differential Revision: D72937710

Reviewed By: mdouze
---
 faiss/utils/simd_levels.cpp | 172 ++++++++++++++++++++++
 faiss/utils/simd_levels.h   |  82 +++++++++++
 tests/test_simd_levels.cpp  | 280 ++++++++++++++++++++++++++++++++++++
 3 files changed, 534 insertions(+)
 create mode 100644 faiss/utils/simd_levels.cpp
 create mode 100644 faiss/utils/simd_levels.h
 create mode 100644 tests/test_simd_levels.cpp

diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp
new file mode 100644
index 0000000000..887225ee3b
--- /dev/null
+++ b/faiss/utils/simd_levels.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+
+namespace faiss {
+
+SIMDLevel SIMDConfig::level = SIMDLevel::NONE;
+std::unordered_set<SIMDLevel>& SIMDConfig::supported_simd_levels() {
+    static std::unordered_set<SIMDLevel> levels;
+    return levels;
+}
+
+// this instance is here to make sure the constructor runs
+static SIMDConfig dummy_config;
+
+SIMDConfig::SIMDConfig(const char** faiss_simd_level_env) {
+    // added to support dependency injection
+    const char* env_var = faiss_simd_level_env ?
*faiss_simd_level_env + : getenv("FAISS_SIMD_LEVEL"); + + // check environment variable for SIMD level is explicitly set + if (!env_var) { + level = auto_detect_simd_level(); + } else { + auto matched_level = to_simd_level(env_var); + if (matched_level.has_value()) { + set_level(matched_level.value()); + supported_simd_levels().clear(); + supported_simd_levels().insert(matched_level.value()); + } else { + fprintf(stderr, + "FAISS_SIMD_LEVEL is set to %s, which is unknown\n", + env_var); + exit(1); + } + } + supported_simd_levels().insert(SIMDLevel::NONE); +} + +void SIMDConfig::set_level(SIMDLevel l) { + level = l; +} + +SIMDLevel SIMDConfig::get_level() { + return level; +} + +std::string SIMDConfig::get_level_name() { + return to_string(level).value_or(""); +} + +bool SIMDConfig::is_simd_level_available(SIMDLevel l) { + return supported_simd_levels().find(l) != supported_simd_levels().end(); +} + +SIMDLevel SIMDConfig::auto_detect_simd_level() { + SIMDLevel level = SIMDLevel::NONE; + +#if defined(__x86_64__) && \ + (defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512)) + unsigned int eax, ebx, ecx, edx; + + eax = 1; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + bool has_avx = (ecx & (1 << 28)) != 0; + + bool has_xsave_osxsave = + (ecx & ((1 << 26) | (1 << 27))) == ((1 << 26) | (1 << 27)); + + bool avx_supported = false; + if (has_avx && has_xsave_osxsave) { + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + avx_supported = (xcr0 & 6) == 6; + } + + if (avx_supported) { + eax = 7; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + +#if defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512) + bool has_avx2 = (ebx & (1 << 5)) != 0; + if (has_avx2) { + SIMDConfig::supported_simd_levels().insert(SIMDLevel::AVX2); + level = SIMDLevel::AVX2; + } + +#if defined(COMPILE_SIMD_AVX512) + bool cpu_has_avx512f = (ebx & (1 << 16)) != 0; + bool os_supports_avx512 = (xcr0 & 0xE0) == 0xE0; + bool has_avx512f = cpu_has_avx512f && os_supports_avx512; + if (has_avx512f) { + bool has_avx512cd = (ebx & (1 << 28)) != 0; + bool has_avx512vl = (ebx & (1 << 31)) != 0; + bool has_avx512dq = (ebx & (1 << 17)) != 0; + bool has_avx512bw = (ebx & (1 << 30)) != 0; + if (has_avx512bw && has_avx512cd && has_avx512vl && has_avx512dq) { + level = SIMDLevel::AVX512; + supported_simd_levels().insert(SIMDLevel::AVX512); + } + } +#endif // defined(COMPILE_SIMD_AVX512) +#endif // defined(COMPILE_SIMD_AVX2)|| defined(COMPILE_SIMD_AVX512) + } +#endif // defined(__x86_64__) && (defined(COMPILE_SIMD_AVX2) || + // defined(COMPILE_SIMD_AVX512)) + +#if defined(__aarch64__) && defined(__ARM_NEON) && \ + defined(COMPILE_SIMD_ARM_NEON) + // ARM NEON is standard on aarch64, so we can assume it's available + supported_simd_levels().insert(SIMDLevel::ARM_NEON); + level = SIMDLevel::ARM_NEON; + + // TODO: Add ARM SVE detection when needed + // For now, we default to ARM_NEON as it's universally supported on aarch64 +#endif + + return level; +} + +std::optional to_string(SIMDLevel level) { + switch (level) { + case SIMDLevel::NONE: + return "NONE"; + case SIMDLevel::AVX2: + return "AVX2"; + case SIMDLevel::AVX512: + return "AVX512"; + case SIMDLevel::ARM_NEON: + return "ARM_NEON"; + default: + return std::nullopt; + } + return std::nullopt; +} + +std::optional to_simd_level(const std::string& level_str) { 
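+    // Inverse of to_string(): map the FAISS_SIMD_LEVEL string (e.g. "AVX512")
+    // to the corresponding enum value; unknown names yield std::nullopt.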
+ if (level_str == "NONE") { + return SIMDLevel::NONE; + } + if (level_str == "AVX2") { + return SIMDLevel::AVX2; + } + if (level_str == "AVX512") { + return SIMDLevel::AVX512; + } + if (level_str == "ARM_NEON") { + return SIMDLevel::ARM_NEON; + } + + return std::nullopt; +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h new file mode 100644 index 0000000000..ad3d0b289d --- /dev/null +++ b/faiss/utils/simd_levels.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace faiss { + +#define COMPILE_SIMD_NONE + +enum class SIMDLevel { + NONE, + // x86 + AVX2, + AVX512, + // arm & aarch64 + ARM_NEON, + + COUNT +}; + +std::optional to_string(SIMDLevel level); + +std::optional to_simd_level(const std::string& level_str); + +/* Current SIMD configuration. This static class manages the current SIMD level + * and intializes it from the cpuid and the FAISS_SIMD_LEVEL + * environment variable */ +struct SIMDConfig { + static SIMDLevel level; + static std::unordered_set& supported_simd_levels(); + + typedef SIMDLevel (*DetectSIMDLevelFunc)(); + static SIMDLevel auto_detect_simd_level(); + + SIMDConfig(const char** faiss_simd_level_env = nullptr); + + static void set_level(SIMDLevel level); + static SIMDLevel get_level(); + static std::string get_level_name(); + + static bool is_simd_level_available(SIMDLevel level); +}; + +/*********************** x86 SIMD */ + +#ifdef COMPILE_SIMD_AVX2 +#define DISPATCH_SIMDLevel_AVX2(f, ...) \ + case SIMDLevel::AVX2: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX2(f, ...) +#endif + +#ifdef COMPILE_SIMD_AVX512 +#define DISPATCH_SIMDLevel_AVX512(f, ...) \ + case SIMDLevel::AVX512F: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX512(f, ...) +#endif + +/* dispatch function f to f */ + +#define DISPATCH_SIMDLevel(f, ...) \ + switch (SIMDConfig::level) { \ + case SIMDLevel::NONE: \ + return f(__VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \ + default: \ + FAISS_ASSERT(!"Invalid SIMD level"); \ + } + +} // namespace faiss diff --git a/tests/test_simd_levels.cpp b/tests/test_simd_levels.cpp new file mode 100644 index 0000000000..4dac2e9877 --- /dev/null +++ b/tests/test_simd_levels.cpp @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#ifdef __x86_64__ +#include +#endif + +#include + +static jmp_buf jmpbuf; +static void sigill_handler(int sig) { + longjmp(jmpbuf, 1); +} + +bool try_execute(void (*func)()) { + signal(SIGILL, sigill_handler); + if (setjmp(jmpbuf) == 0) { + func(); + signal(SIGILL, SIG_DFL); + return true; + } else { + signal(SIGILL, SIG_DFL); + return false; + } +} + +#ifdef __x86_64__ +std::vector run_avx2_computation() { + alignas(32) int result[8]; + alignas(32) int input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(32) int input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m256i vec1 = _mm256_load_si256(reinterpret_cast<__m256i*>(input1)); + __m256i vec2 = _mm256_load_si256(reinterpret_cast<__m256i*>(input2)); + __m256i vec_result = _mm256_add_epi32(vec1, vec2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result), vec_result); + + return {result, result + 8}; +} + +std::vector run_avx512f_computation() { + alignas(64) long long result[8]; + alignas(64) long long input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(64) long long input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m512i vec1 = _mm512_load_si512(reinterpret_cast(input1)); + __m512i vec2 = _mm512_load_si512(reinterpret_cast(input2)); + __m512i vec_result = _mm512_add_epi64(vec1, vec2); + _mm512_store_si512(reinterpret_cast<__m512i*>(result), vec_result); + + return {result, result + 8}; +} + +std::vector run_avx512cd_computation() { + run_avx512f_computation(); + + __m512i indices = _mm512_set_epi32( + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + __m512i conflict_mask = _mm512_conflict_epi32(indices); + + alignas(64) int mask_array[16]; + _mm512_store_epi32(mask_array, conflict_mask); + + return std::vector(); +} + +std::vector run_avx512vl_computation() { + run_avx512f_computation(); + + __m256i vec1 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + __m256i vec2 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i result = _mm256_add_epi32(vec1, vec2); + alignas(32) int result_array[8]; + _mm256_store_si256(reinterpret_cast<__m256i*>(result_array), result); + + return std::vector(result_array, result_array + 8); +} + +std::vector run_avx512dq_computation() { + run_avx512f_computation(); + + __m512i vec1 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); + __m512i vec2 = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + __m512i result = _mm512_add_epi64(vec1, vec2); + + alignas(64) long long result_array[8]; + _mm512_store_si512(result_array, result); + + return std::vector(result_array, result_array + 8); +} + +std::vector run_avx512bw_computation() { + run_avx512f_computation(); + + std::vector input1(64, 0); + __m512i vec1 = + _mm512_loadu_si512(reinterpret_cast(input1.data())); + std::vector input2(64, 7); + __m512i vec2 = + _mm512_loadu_si512(reinterpret_cast(input2.data())); + __m512i result = _mm512_add_epi8(vec1, vec2); + + alignas(64) int8_t result_array[64]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result_array), result); + + return std::vector(result_array, result_array + 64); +} +#endif // __x86_64__ + +std::pair> try_execute(std::vector (*func)()) { + signal(SIGILL, sigill_handler); + if (setjmp(jmpbuf) == 0) { + auto result = func(); + signal(SIGILL, SIG_DFL); + return std::make_pair(true, result); + } else { + signal(SIGILL, SIG_DFL); + return std::make_pair(false, std::vector()); + } +} + +TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { + faiss::SIMDLevel detected_level = + faiss::SIMDConfig::auto_detect_simd_level(); + +#if defined(__x86_64__) && \ + (defined(__AVX2__) || \ + 
(defined(__AVX512F__) && defined(__AVX512CD__) && \ + defined(__AVX512VL__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__))) + EXPECT_TRUE( + detected_level == faiss::SIMDLevel::AVX2 || + detected_level == faiss::SIMDLevel::AVX512); +#elif defined(__aarch64__) && defined(__ARM_NEON) + EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); +#else + EXPECT_EQ(detected_level, faiss::SIMDLevel::NONE); +#endif +} + +#ifdef __x86_64__ +TEST(SIMDConfig, successful_avx2_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)) { + auto actual_result = try_execute(run_avx2_computation); + EXPECT_TRUE(actual_result.first); + auto expected_result_vector = std::vector(8, 9); + EXPECT_EQ(actual_result.second, expected_result_vector); + } +} + +TEST(SIMDConfig, on_avx512f_supported_we_should_avx2_support_as_well) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)); + } +} + +TEST(SIMDConfig, successful_avx512f_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual_result = try_execute(run_avx512f_computation); + EXPECT_TRUE(actual_result.first); + auto expected_result_vector = std::vector(8, 9); + EXPECT_EQ(actual_result.second, expected_result_vector); + } +} + +TEST(SIMDConfig, successful_avx512cd_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = try_execute(run_avx512cd_computation); + EXPECT_TRUE(actual.first); + } +} + +TEST(SIMDConfig, successful_avx512vl_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = try_execute(run_avx512vl_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(8, 7)); + } +} + +TEST(SIMDConfig, successful_avx512dq_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = try_execute(run_avx512dq_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(8, 7)); + } +} + +TEST(SIMDConfig, successful_avx512bw_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = try_execute(run_avx512bw_computation); + EXPECT_TRUE(actual.first); + EXPECT_EQ(actual.second, std::vector(64, 7)); + } +} +#endif // __x86_64__ + +TEST(SIMDConfig, override_simd_level) { + const char* faiss_env_var_neon = "ARM_NEON"; + faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + + EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); + EXPECT_TRUE(simd_neon_config.is_simd_level_available( + faiss::SIMDLevel::ARM_NEON)); + + const char* faiss_env_var_avx512 = "AVX512"; + faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); + EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512); + EXPECT_EQ(simd_avx512_config.supported_simd_levels().size(), 2); + 
EXPECT_TRUE(simd_avx512_config.is_simd_level_available( + faiss::SIMDLevel::AVX512)); +} + +TEST(SIMDConfig, simd_config_get_level_name) { + const char* faiss_env_var_neon = "ARM_NEON"; + faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + EXPECT_TRUE(simd_neon_config.is_simd_level_available( + faiss::SIMDLevel::ARM_NEON)); + EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); + + const char* faiss_env_var_avx512 = "AVX512"; + faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); + EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512); + EXPECT_TRUE(simd_avx512_config.is_simd_level_available( + faiss::SIMDLevel::AVX512)); + EXPECT_EQ(faiss_env_var_avx512, simd_avx512_config.get_level_name()); +} + +TEST(SIMDLevel, get_level_name_from_enum) { + EXPECT_EQ("NONE", to_string(faiss::SIMDLevel::NONE).value_or("")); + EXPECT_EQ("AVX2", to_string(faiss::SIMDLevel::AVX2).value_or("")); + EXPECT_EQ("AVX512", to_string(faiss::SIMDLevel::AVX512).value_or("")); + EXPECT_EQ("ARM_NEON", to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); + + int actual_num_simd_levels = static_cast(faiss::SIMDLevel::COUNT); + EXPECT_EQ(4, actual_num_simd_levels); + // Check that all SIMD levels have a name (except for COUNT which is not a + // real SIMD level) + for (int i = 0; i < actual_num_simd_levels - 1; ++i) { + faiss::SIMDLevel simd_level = static_cast(i); + EXPECT_TRUE(faiss::to_string(simd_level).has_value()); + } +} + +TEST(SIMDLevel, to_simd_level_from_string) { + EXPECT_EQ(faiss::SIMDLevel::NONE, faiss::to_simd_level("NONE")); + EXPECT_EQ(faiss::SIMDLevel::AVX2, faiss::to_simd_level("AVX2")); + EXPECT_EQ(faiss::SIMDLevel::AVX512, faiss::to_simd_level("AVX512")); + EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); + EXPECT_FALSE(faiss::to_simd_level("INVALID").has_value()); +} From 3c61326ea8981a46e8776c05bd924f6b211070ff Mon Sep 17 00:00:00 2001 From: matthijs Date: Thu, 28 Aug 2025 00:59:30 -0700 Subject: [PATCH 2/5] dynamic dispatch distances_simd Summary: `fvec_madd` is the first function to test dispatching to AVX and AVX512 distances_simd.cpp is split into specialized files distances_avx2.cpp distances_avx512.cpp that are compiled with appropriate flags. 
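As an illustration of the dispatch pattern (a minimal sketch, simplified from the diff below; implementation bodies are elided): each kernel is specialized on `SIMDLevel` in its own translation unit, and the plain entry point forwards to the right specialization at run time based on `SIMDConfig::level`.

    // faiss/utils/simd_impl/distances_avx2.cpp -- compiled with the AVX2 flags
    template <>
    void fvec_madd<SIMDLevel::AVX2>(
            size_t n, const float* a, float bf, const float* b, float* c) {
        // ... AVX2 implementation ...
    }

    // faiss/utils/distances_simd.cpp -- the runtime dispatcher
    void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
        // DISPATCH_SIMDLevel expands to a switch on SIMDConfig::level that
        // forwards to fvec_madd<SIMDLevel::NONE>, fvec_madd<SIMDLevel::AVX2>, ...
        // depending on which levels were compiled in.
        DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c);
    }

This split lets each SIMD-specific translation unit be built with the appropriate architecture flags without enabling them for the rest of the library.
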
Differential Revision: D72937708 Reviewed By: mnorris11 --- faiss/utils/distances.h | 98 + faiss/utils/distances_simd.cpp | 3589 +---------------- faiss/utils/extra_distances-inl.h | 7 - faiss/utils/simd_impl/distances_aarch64.cpp | 137 + faiss/utils/simd_impl/distances_arm_sve.cpp | 496 +++ faiss/utils/simd_impl/distances_autovec-inl.h | 153 + faiss/utils/simd_impl/distances_avx.cpp | 99 + faiss/utils/simd_impl/distances_avx2.cpp | 1178 ++++++ faiss/utils/simd_impl/distances_avx512.cpp | 1092 +++++ faiss/utils/simd_impl/distances_sse-inl.h | 385 ++ faiss/utils/simd_levels.cpp | 3 +- faiss/utils/simd_levels.h | 2 +- tests/test_distances_simd.cpp | 532 ++- tests/test_simd_levels.cpp | 156 +- 14 files changed, 4197 insertions(+), 3730 deletions(-) create mode 100644 faiss/utils/simd_impl/distances_aarch64.cpp create mode 100644 faiss/utils/simd_impl/distances_arm_sve.cpp create mode 100644 faiss/utils/simd_impl/distances_autovec-inl.h create mode 100644 faiss/utils/simd_impl/distances_avx.cpp create mode 100644 faiss/utils/simd_impl/distances_avx2.cpp create mode 100644 faiss/utils/simd_impl/distances_avx512.cpp create mode 100644 faiss/utils/simd_impl/distances_sse-inl.h diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index 80d2cfc699..3531b10845 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -15,6 +15,7 @@ #include #include +#include namespace faiss { @@ -27,15 +28,27 @@ struct IDSelector; /// Squared L2 distance between two vectors float fvec_L2sqr(const float* x, const float* y, size_t d); +template +float fvec_L2sqr(const float* x, const float* y, size_t d); + /// inner product float fvec_inner_product(const float* x, const float* y, size_t d); +template +float fvec_inner_product(const float* x, const float* y, size_t d); + /// L1 distance float fvec_L1(const float* x, const float* y, size_t d); +template +float fvec_L1(const float* x, const float* y, size_t d); + /// infinity distance float fvec_Linf(const float* x, const float* y, size_t d); +template +float fvec_Linf(const float* x, const float* y, size_t d); + /// Special version of inner product that computes 4 distances /// between x and yi, which is performance oriented. void fvec_inner_product_batch_4( @@ -50,6 +63,19 @@ void fvec_inner_product_batch_4( float& dis2, float& dis3); +template +void fvec_inner_product_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. 
void fvec_L2sqr_batch_4( @@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4( float& dis2, float& dis3); +template +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3); + /** Compute pairwise distances between sets of vectors * * @param d dimension of the vectors @@ -93,6 +132,14 @@ void fvec_inner_products_ny( size_t d, size_t ny); +template +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of contiguous y vectors */ void fvec_L2sqr_ny( float* dis, @@ -101,6 +148,14 @@ void fvec_L2sqr_ny( size_t d, size_t ny); +template +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of transposed contiguous y vectors. squared lengths of y should be provided as well */ void fvec_L2sqr_ny_transposed( @@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed( size_t d_offset, size_t ny); +template +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + /* compute ny square L2 distance between x and a set of contiguous y vectors and return the index of the nearest vector. return 0 if ny == 0. */ @@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest( size_t d, size_t ny); +template +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + /* compute ny square L2 distance between x and a set of transposed contiguous y vectors and return the index of the nearest vector. 
squared lengths of y should be provided as well @@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed( size_t d_offset, size_t ny); +template +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); +template +float fvec_norm_L2sqr(const float* x, size_t d); + /** compute the L2 norms for a set of vectors * * @param norms output norms, size nx @@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2( */ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); +/* same statically */ +template +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); + /** same as fvec_madd, also return index of the min of the result table * @return index of the min of table c */ @@ -483,4 +573,12 @@ int fvec_madd_and_argmin( const float* b, float* c); +template +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c); + } // namespace faiss diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index c6ff8b57cb..ab174a5a54 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -19,85 +18,28 @@ #include #include -#ifdef __SSE3__ -#include -#endif - -#if defined(__AVX512F__) -#include -#elif defined(__AVX2__) -#include -#endif - -#ifdef __ARM_FEATURE_SVE -#include -#endif - -#ifdef __aarch64__ -#include -#endif +#define AUTOVEC_LEVEL SIMDLevel::NONE +#include namespace faiss { -#ifdef __AVX__ -#define USE_AVX -#endif - -/********************************************************* - * Optimized distance computations - *********************************************************/ - -/* Functions to compute: - - L2 distance between 2 vectors - - inner product between 2 vectors - - L2 norm of a vector - - The functions should probably not be invoked when a large number of - vectors are be processed in batch (in which case Matrix multiply - is faster), but may be useful for comparing vectors isolated in - memory. - - Works with any vectors of any dimension, even unaligned (in which - case they are slower). 
- +/******* +Functions with SIMDLevel::NONE */ -/********************************************************* - * Reference implementations - */ - -float fvec_L1_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += fabs(tmp); - } - return res; -} - -float fvec_Linf_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - res = fmax(res, fabs(x[i] - y[i])); - } - return res; -} - -void fvec_L2sqr_ny_ref( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - for (size_t i = 0; i < ny; i++) { - dis[i] = fvec_L2sqr(x, y, d); - y += d; - } +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + for (size_t i = 0; i < n; i++) + c[i] = a[i] + bf * b[i]; } -void fvec_L2sqr_ny_y_transposed_ref( +template <> +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, @@ -120,13 +62,50 @@ void fvec_L2sqr_ny_y_transposed_ref( } } -size_t fvec_L2sqr_ny_nearest_ref( +template <> +void fvec_inner_products_ny( + float* ip, + const float* x, + const float* y, + size_t d, + size_t ny) { +// BLAS slower for the use cases here +#if 0 +{ + FINTEGER di = d; + FINTEGER nyi = ny; + float one = 1.0, zero = 0.0; + FINTEGER onei = 1; + sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); +} +#endif + for (size_t i = 0; i < ny; i++) { + ip[i] = fvec_inner_product(x, y, d); + y += d; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + for (size_t i = 0; i < ny; i++) { + dis[i] = fvec_L2sqr(x, y, d); + y += d; + } +} + +template <> +size_t fvec_L2sqr_ny_nearest( float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny) { - fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); + fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); size_t nearest_idx = 0; float min_dis = HUGE_VALF; @@ -141,7 +120,8 @@ size_t fvec_L2sqr_ny_nearest_ref( return nearest_idx; } -size_t fvec_L2sqr_ny_nearest_y_transposed_ref( +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( float* distances_tmp_buffer, const float* x, const float* y, @@ -149,7 +129,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( size_t d, size_t d_offset, size_t ny) { - fvec_L2sqr_ny_y_transposed_ref( + fvec_L2sqr_ny_transposed( distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); size_t nearest_idx = 0; @@ -165,73 +145,54 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( return nearest_idx; } -void fvec_inner_products_ny_ref( - float* ip, - const float* x, - const float* y, - size_t d, - size_t ny) { - // BLAS slower for the use cases here -#if 0 - { - FINTEGER di = d; - FINTEGER nyi = ny; - float one = 1.0, zero = 0.0; - FINTEGER onei = 1; - sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); - } -#endif - for (size_t i = 0; i < ny; i++) { - ip[i] = fvec_inner_product(x, y, d); - y += d; +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + float vmin = 1e20; + int imin = -1; + + for (size_t i = 0; i < n; i++) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = i; + } } + return imin; } /********************************************************* - * Autovectorized implementations + * dispatching functions */ -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { - float res = 0.F; - 
FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * y[i]; - } - return res; +float fvec_L1(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_L1, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_norm_L2sqr(const float* x, size_t d) { - // the double in the _ref is suspected to be a typo. Some of the manual - // implementations this replaces used float. - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * x[i]; - } +float fvec_Linf(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_Linf, x, y, d); +} - return res; +// dispatching functions + +float fvec_norm_L2sqr(const float* x, size_t d) { + DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += tmp * tmp; - } - return res; + DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d); +} + +float fvec_inner_product(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_inner_product, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of inner product that computes 4 distances /// between x and yi -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_inner_product_batch_4( const float* __restrict x, const float* __restrict y0, @@ -243,28 +204,22 @@ void fvec_inner_product_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - d0 += x[i] * y0[i]; - d1 += x[i] * y1[i]; - d2 += x[i] * y2[i]; - d3 += x[i] * y3[i]; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; + DISPATCH_SIMDLevel( + fvec_inner_product_batch_4, + x, + y0, + y1, + y2, + y3, + d, + dis0, + dis1, + dis2, + dis3); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_L2sqr_batch_4( const float* x, const float* y0, @@ -276,3326 +231,72 @@ void fvec_L2sqr_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - const float q0 = x[i] - y0[i]; - const float q1 = x[i] - y1[i]; - const float q2 = x[i] - y2[i]; - const float q3 = x[i] - y3[i]; - d0 += q0 * q0; - d1 += q1 * q1; - d2 += q2 * q2; - d3 += q3 * q3; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; -} -FAISS_PRAGMA_IMPRECISE_FUNCTION_END - -/********************************************************* - * SSE and AVX implementations - */ - -#ifdef __SSE3__ - -// reads 0 <= d < 4 floats as __m128 -static inline __m128 masked_read(int d, const float* x) { - assert(0 <= d && d < 4); - ALIGNED(16) float buf[4] = {0, 0, 0, 0}; - switch (d) { - case 3: - buf[2] = x[2]; - [[fallthrough]]; - case 2: - buf[1] = x[1]; - [[fallthrough]]; - case 1: - buf[0] = x[0]; - } - return _mm_load_ps(buf); - // cannot use AVX2 _mm_mask_set1_epi32 -} - -namespace { - -/// helper function -inline float horizontal_sum(const __m128 v) { - // say, v is [x0, x1, x2, x3] - - // v0 is [x2, x3, ..., ...] - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - // v1 is [x0 + x2, x1 + x3, ..., ...] 
- const __m128 v1 = _mm_add_ps(v, v0); - // v2 is [x1 + x3, ..., .... ,...] - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] - const __m128 v3 = _mm_add_ps(v1, v2); - // return v3[0] - return _mm_cvtss_f32(v3); -} - -#ifdef __AVX2__ -/// helper function for AVX2 -inline float horizontal_sum(const __m256 v) { - // add high and low parts - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - // perform horizontal sum on v0 - return horizontal_sum(v0); -} -#endif - -#ifdef __AVX512F__ -/// helper function for AVX512 -inline float horizontal_sum(const __m512 v) { - // performs better than adding the high and low parts - return _mm512_reduce_add_ps(v); -} -#endif - -/// Function that does a component-wise operation between x and y -/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny -/// functions below -struct ElementOpL2 { - static float op(float x, float y) { - float tmp = x - y; - return tmp * tmp; - } - - static __m128 op(__m128 x, __m128 y) { - __m128 tmp = _mm_sub_ps(x, y); - return _mm_mul_ps(tmp, tmp); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - __m256 tmp = _mm256_sub_ps(x, y); - return _mm256_mul_ps(tmp, tmp); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - __m512 tmp = _mm512_sub_ps(x, y); - return _mm512_mul_ps(tmp, tmp); - } -#endif -}; - -/// Function that does a component-wise operation between x and y -/// to compute inner products -struct ElementOpIP { - static float op(float x, float y) { - return x * y; - } - - static __m128 op(__m128 x, __m128 y) { - return _mm_mul_ps(x, y); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - return _mm256_mul_ps(x, y); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - return _mm512_mul_ps(x, y); - } -#endif -}; - -template -void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { - float x0s = x[0]; - __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); - - size_t i; - for (i = 0; i + 3 < ny; i += 4) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = _mm_cvtss_f32(accu); - __m128 tmp = _mm_shuffle_ps(accu, accu, 1); - dis[i + 1] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 2); - dis[i + 2] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 3); - dis[i + 3] = _mm_cvtss_f32(tmp); - } - while (i < ny) { // handle non-multiple-of-4 case - dis[i++] = ElementOp::op(x0s, *y++); - } -} - -template -void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); - - size_t i; - for (i = 0; i + 1 < ny; i += 2) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - accu = _mm_shuffle_ps(accu, accu, 3); - dis[i + 1] = _mm_cvtss_f32(accu); - } - if (i < ny) { // handle odd case - dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); - } + DISPATCH_SIMDLevel( + fvec_L2sqr_batch_4, x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3); } -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. 
- _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute distances (dot product) - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_transposed, dis, x, y, y_sqlen, d, d_offset, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +void fvec_inner_products_ny( + float* ip, /* output inner product */ const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_inner_products_ny, ip, x, y, d, ny); } -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny( float* dis, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_L2sqr_ny, dis, x, y, d, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest, distances_tmp_buffer, x, y, d, ny); +} - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest_y_transposed, + distances_tmp_buffer, + x, + y, + y_sqlen, + d, + d_offset, + ny); +} - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } -} - -#endif - -template -void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_hadd_ps(accu, accu); - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - distances = _mm512_fmadd_ps(m4, v4, distances); - distances = _mm512_fmadd_ps(m5, v5, distances); - distances = _mm512_fmadd_ps(m6, v6, distances); - distances = _mm512_fmadd_ps(m7, v7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - distances = _mm256_fmadd_ps(m4, v4, distances); - distances = _mm256_fmadd_ps(m5, v5, distances); - distances = _mm256_fmadd_ps(m6, v6, distances); - distances = _mm256_fmadd_ps(m7, v7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - __m128 x2 = _mm_loadu_ps(x + 8); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -} // anonymous namespace - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases - -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_L2sqr_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_inner_products_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // m[i] = (2 * x[i], ... 
2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] - } - - __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); - - for (; i < ny16 * 16; i += 16) { - // Load vectors for 16 dimensions - __m512 v[DIM]; - for (size_t j = 0; j < DIM; j++) { - v[j] = _mm512_loadu_ps(y + j * d_offset); - } - - // Compute dot products - __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); - for (size_t j = 1; j < DIM; j++) { - dp = _mm512_fnmadd_ps(m[j], v[j], dp); - } - - // Compute y^2 - (2 * x, y) + x^2 - __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); - - _mm512_storeu_ps(distances + i, distances_v); - - // Scroll y and y_sqlen forward - y += 16; - y_sqlen += 16; - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#elif defined(__AVX2__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. - const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - - // compute dot products - // this is x^2 - 2x[0]*y[0] - __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fnmadd_ps(m[j], vj, dp); - } - - // we've got x^2 - (2x, y) at this point - - // y^2 - (2x, y) + x^2 - __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); - - _mm256_storeu_ps(distances + i, distances_v); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. 
- const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#endif - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases - -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_y_transposed_D( \ - dis, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_y_transposed_ref( - dis, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#if defined(__AVX512F__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 32; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
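// Hedged sketch of why distances_tmp_buffer goes unused in these nearest-neighbor
// kernels: rather than materializing all ny distances and scanning them afterwards,
// each kernel keeps a running minimum and its index per SIMD lane in registers and
// finishes with one horizontal reduction. The scalar equivalent of that single-pass
// strategy (illustrative name, assumes <cmath> for HUGE_VALF):
size_t l2sqr_nearest_single_pass(const float* x, const float* y, size_t d, size_t ny) {
    float best_dist = HUGE_VALF;
    size_t best_idx = 0;
    for (size_t i = 0; i < ny; i++, y += d) {
        float dist = 0.0f;
        for (size_t j = 0; j < d; j++) {
            const float diff = x[j] - y[j];
            dist += diff * diff;
        }
        if (dist < best_dist) { // no temporary distance array needed
            best_dist = dist;
            best_idx = i;
        }
    }
    return best_idx;
}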
- - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 64; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
- - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 128; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D2-vectors per loop. 
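// The AVX2 kernels below track per-lane best indices with a trick worth spelling
// out: AVX2 has no integer blend driven directly by a float comparison, so the
// index vectors are reinterpreted as floats, blended with _mm256_blendv_ps under
// the comparison mask, and cast back. A self-contained sketch of that idiom
// (illustrative helper name, assumes <immintrin.h>):
static inline __m256i blend_min_indices_avx2(
        __m256 min_distances,      // current per-lane minima
        __m256 distances,          // freshly computed per-lane distances
        __m256i min_indices,       // indices belonging to the current minima
        __m256i current_indices) { // indices of the fresh candidates
    // all-ones lanes where the old minimum is still strictly smaller
    const __m256 keep_old = _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
    // blendv picks its second operand where the mask is set, so old indices are
    // kept in exactly those lanes and replaced by the candidates everywhere else
    return _mm256_castps_si256(_mm256_blendv_ps(
            _mm256_castsi256_ps(current_indices),
            _mm256_castsi256_ps(min_indices),
            keep_old));
}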
- const size_t ny8 = ny / 8; - if (ny8 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 2 DIM each). - y += 16; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers. - // the following code is not optimal, but it is rarely invoked. - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D4-vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 4 DIM each). - y += 32; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D8-vectors per loop. - const size_t ny8 = ny / 8; - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 8 DIM each). 
- y += 64; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#else -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny); -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny); -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny); -} -#endif - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny); - - switch (d) { - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // This implementation does not use distances_tmp_buffer. - - // Current index being processed - size_t i = 0; - - // Min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // Process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // Track min distance and the closest vector independently - // for each of 16 AVX-512 components. - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); - } - - for (; i < ny16 * 16; i += 16) { - // Compute dot products - const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); - __m512 dp = _mm512_mul_ps(m[0], v0); - for (size_t j = 1; j < DIM; j++) { - const __m512 vj = _mm512_loadu_ps(y + j * d_offset); - dp = _mm512_fmadd_ps(m[j], vj, dp); - } - - // Compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. 
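// Why dropping x^2 is safe here: ||x - y_i||^2 = ||x||^2 - 2<x, y_i> + ||y_i||^2,
// and ||x||^2 is the same for every candidate i, so
//     argmin_i ||x - y_i||^2 == argmin_i (||y_i||^2 - 2<x, y_i>).
// The shifted score computed below therefore preserves the ranking but is offset
// from the true squared distance by ||x||^2, so it may only be compared, never
// reported as a distance. Tiny illustrative helper (not part of the patch):
inline float shifted_l2_score(float y_sqlen_i, float dot_x_yi) {
    return y_sqlen_i - 2.0f * dot_x_yi; // == ||x - y_i||^2 - ||x||^2
}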
- const __m512 distances = - _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); - - // Compare the new distances to the min distances - __mmask16 comparison = - _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); - - // Update min distances and indices with closest vectors if needed - min_distances = - _mm512_mask_blend_ps(comparison, distances, min_distances); - min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( - comparison, - _mm512_castsi512_ps(current_indices), - _mm512_castsi512_ps(min_indices))); - - // Update current indices values. Basically, +16 to each of the 16 - // AVX-512 components. - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - // Scroll y and y_sqlen forward. - y += 16; - y_sqlen += 16; - } - - // Dump values and find the minimum distance / minimum index - float min_distances_scalar[16]; - uint32_t min_indices_scalar[16]; - _mm512_storeu_ps(min_distances_scalar, min_distances); - _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. - const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - // compute dot products - __m256 dp = _mm256_mul_ps(m[0], v0); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fmadd_ps(m[j], vj, dp); - } - - // compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. - const __m256 distances = - _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - const __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. 
- min_distances = - _mm256_blendv_ps(distances, min_distances, comparison); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#endif - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_y_transposed_D( \ - distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#endif - -#ifdef USE_AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find sum of absolute value of distances (manhattan distance) - msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_hadd_ps(msum2, msum2); - msum2 = _mm_hadd_ps(msum2, msum2); - return _mm_cvtss_f32(msum2); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 signmask = 
_mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find max of absolute value of distances (chebyshev distance) - msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); - msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); - return _mm_cvtss_f32(msum2); -} - -#elif defined(__SSE3__) // But not AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -#elif defined(__ARM_FEATURE_SVE) - -struct ElementOpIP { - static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { - return svmul_f32_x(pg, x, y); - } - static svfloat32_t merge( - svbool_t pg, - svfloat32_t z, - svfloat32_t x, - svfloat32_t y) { - return svmla_f32_x(pg, z, x, y); - } -}; - -template -void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - size_t i = 0; - for (; i + lanes4 < ny; i += lanes4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - svst1_f32(pg, dis, y0); - svst1_f32(pg, dis + lanes, y1); - svst1_f32(pg, dis + lanes2, y2); - svst1_f32(pg, dis + lanes3, y3); - y += lanes4; - dis += lanes4; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); - const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); - svfloat32_t y0 = svld1_f32(pg0, y); - svfloat32_t y1 = svld1_f32(pg1, y + lanes); - svfloat32_t y2 = svld1_f32(pg2, y + lanes2); - svfloat32_t y3 = svld1_f32(pg3, y + lanes3); - y0 = ElementOp::op(pg0, x0, y0); - y1 = ElementOp::op(pg1, x0, y1); - y2 = ElementOp::op(pg2, x0, y2); - y3 = ElementOp::op(pg3, x0, y3); - svst1_f32(pg0, dis, y0); - svst1_f32(pg1, dis + lanes, y1); - svst1_f32(pg2, dis + lanes2, y2); - svst1_f32(pg3, dis + lanes3, y3); -} - -template -void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - size_t i = 0; - for (; i + lanes2 < ny; i += lanes2) { - const 
svfloat32x2_t y0 = svld2_f32(pg, y); - const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - svst1_f32(pg, dis, y00); - svst1_f32(pg, dis + lanes, y10); - y += lanes4; - dis += lanes2; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svfloat32x2_t y0 = svld2_f32(pg0, y); - const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg0, x0, y00); - y10 = ElementOp::op(pg1, x0, y10); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y10 = ElementOp::merge(pg1, y10, x1, y11); - svst1_f32(pg0, dis, y00); - svst1_f32(pg1, dis + lanes, y10); -} - -template -void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t y0 = svld4_f32(pg, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg, x0, y00); - y02 = ElementOp::op(pg, x2, y02); - y00 = ElementOp::merge(pg, y00, x1, y01); - y02 = ElementOp::merge(pg, y02, x3, y03); - y00 = svadd_f32_x(pg, y00, y02); - svst1_f32(pg, dis, y00); - y += lanes4; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svfloat32x4_t y0 = svld4_f32(pg0, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg0, x0, y00); - y02 = ElementOp::op(pg0, x2, y02); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y02 = ElementOp::merge(pg0, y02, x3, y03); - y00 = svadd_f32_x(pg0, y00, y02); - svst1_f32(pg0, dis, y00); -} - -template -void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const size_t lanes8 = lanes * 8; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - const svfloat32_t x4 = svdup_n_f32(x[4]); - const svfloat32_t x5 = svdup_n_f32(x[5]); - const svfloat32_t x6 = svdup_n_f32(x[6]); - const svfloat32_t x7 = svdup_n_f32(x[7]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t ya = svld4_f32(pg, y); - const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = 
svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y4 = ElementOp::op(pg, x4, y4); - y6 = ElementOp::op(pg, x6, y6); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y4 = ElementOp::merge(pg, y4, x5, y5); - y6 = ElementOp::merge(pg, y6, x7, y7); - y0 = svadd_f32_x(pg, y0, y2); - y4 = svadd_f32_x(pg, y4, y6); - y0 = svadd_f32_x(pg, y0, y4); - svst1_f32(pg, dis, y0); - y += lanes8; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); - const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); - const svfloat32x4_t ya = svld4_f32(pga, y); - const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg0, x0, y0); - y2 = ElementOp::op(pg0, x2, y2); - y4 = ElementOp::op(pg0, x4, y4); - y6 = ElementOp::op(pg0, x6, y6); - y0 = ElementOp::merge(pg0, y0, x1, y1); - y2 = ElementOp::merge(pg0, y2, x3, y3); - y4 = ElementOp::merge(pg0, y4, x5, y5); - y6 = ElementOp::merge(pg0, y6, x7, y7); - y0 = svadd_f32_x(pg0, y0, y2); - y4 = svadd_f32_x(pg0, y4, y6); - y0 = svadd_f32_x(pg0, y0, y4); - svst1_f32(pg0, dis, y0); - y += lanes8; - dis += lanes; -} - -template -void fvec_op_ny_sve_lanes1( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - size_t i = 0; - for (; i + 3 < ny; i += 4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - dis[i] = svaddv_f32(pg, y0); - dis[i + 1] = svaddv_f32(pg, y1); - dis[i + 2] = svaddv_f32(pg, y2); - dis[i + 3] = svaddv_f32(pg, y3); - } - for (; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - y += lanes; - y0 = ElementOp::op(pg, x0, y0); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes2( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = 
svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - size_t i = 0; - for (; i + 1 < ny; i += 2) { - svfloat32_t y00 = svld1_f32(pg, y); - const svfloat32_t y01 = svld1_f32(pg, y + lanes); - svfloat32_t y10 = svld1_f32(pg, y + lanes2); - const svfloat32_t y11 = svld1_f32(pg, y + lanes3); - y += lanes4; - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - dis[i] = svaddv_f32(pg, y00); - dis[i + 1] = svaddv_f32(pg, y10); - } - if (i < ny) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes3( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - y += lanes3; - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - y0 = ElementOp::merge(pg, y0, x2, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - const svfloat32_t x3 = svld1_f32(pg, x + lanes3); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - const svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y0 = svadd_f32_x(pg, y0, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const 
float* x, - const float* y, - size_t d, - size_t ny) { - const size_t lanes = svcntw(); - switch (d) { - case 1: - fvec_op_ny_sve_d1(dis, x, y, ny); - break; - case 2: - fvec_op_ny_sve_d2(dis, x, y, ny); - break; - case 4: - fvec_op_ny_sve_d4(dis, x, y, ny); - break; - case 8: - fvec_op_ny_sve_d8(dis, x, y, ny); - break; - default: - if (d == lanes) - fvec_op_ny_sve_lanes1(dis, x, y, ny); - else if (d == lanes * 2) - fvec_op_ny_sve_lanes2(dis, x, y, ny); - else if (d == lanes * 3) - fvec_op_ny_sve_lanes3(dis, x, y, ny); - else if (d == lanes * 4) - fvec_op_ny_sve_lanes4(dis, x, y, ny); - else - fvec_inner_products_ny_ref(dis, x, y, d, ny); - break; - } -} - -#elif defined(__aarch64__) - -// not optimized for ARM -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#else -// scalar implementation - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#endif - -/*************************************************************************** - * heavily optimized table computations - ***************************************************************************/ - -[[maybe_unused]] static inline void fvec_madd_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - for (size_t 
i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - } -} - -#if defined(__AVX512F__) - -static inline void fvec_madd_avx512( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t n16 = n / 16; - const size_t n_for_masking = n % 16; - - const __m512 bfmm = _mm512_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n16 * 16; idx += 16) { - const __m512 ax = _mm512_loadu_ps(a + idx); - const __m512 bx = _mm512_loadu_ps(b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - const __mmask16 mask = (1 << n_for_masking) - 1; - - const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); - const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_mask_storeu_ps(c + idx, mask, abmul); - } -} - -#elif defined(__AVX2__) - -static inline void fvec_madd_avx2( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - // - const size_t n8 = n / 8; - const size_t n_for_masking = n % 8; - - const __m256 bfmm = _mm256_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n8 * 8; idx += 8) { - const __m256 ax = _mm256_loadu_ps(a + idx); - const __m256 bx = _mm256_loadu_ps(b + idx); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - __m256i mask; - switch (n_for_masking) { - case 1: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); - break; - case 2: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); - break; - case 3: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); - break; - case 4: - mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); - break; - case 5: - mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); - break; - case 6: - mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); - break; - case 7: - mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); - break; - } - - const __m256 ax = _mm256_maskload_ps(a + idx, mask); - const __m256 bx = _mm256_maskload_ps(b + idx, mask); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_maskstore_ps(c + idx, mask, abmul); - } -} - -#endif - -#ifdef __SSE3__ - -[[maybe_unused]] static inline void fvec_madd_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - b4++; - a4++; - c4++; - } -} - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX512F__ - fvec_madd_avx512(n, a, bf, b, c); -#elif __AVX2__ - fvec_madd_avx2(n, a, bf, b, c); -#else - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) - fvec_madd_sse(n, a, bf, b, c); - else - fvec_madd_ref(n, a, bf, b, c); -#endif -} - -#elif defined(__ARM_FEATURE_SVE) - -void fvec_madd( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t lanes = static_cast(svcntw()); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - size_t i = 0; - for (; i + lanes4 < n; i += lanes4) { - const auto mask = svptrue_b32(); - const auto ai0 = svld1_f32(mask, a + i); - const auto ai1 = svld1_f32(mask, a + i + lanes); - const auto ai2 = svld1_f32(mask, a + i + 
lanes2); - const auto ai3 = svld1_f32(mask, a + i + lanes3); - const auto bi0 = svld1_f32(mask, b + i); - const auto bi1 = svld1_f32(mask, b + i + lanes); - const auto bi2 = svld1_f32(mask, b + i + lanes2); - const auto bi3 = svld1_f32(mask, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); - const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); - svst1_f32(mask, c + i, ci0); - svst1_f32(mask, c + i + lanes, ci1); - svst1_f32(mask, c + i + lanes2, ci2); - svst1_f32(mask, c + i + lanes3, ci3); - } - const auto mask0 = svwhilelt_b32_u64(i, n); - const auto mask1 = svwhilelt_b32_u64(i + lanes, n); - const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); - const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); - const auto ai0 = svld1_f32(mask0, a + i); - const auto ai1 = svld1_f32(mask1, a + i + lanes); - const auto ai2 = svld1_f32(mask2, a + i + lanes2); - const auto ai3 = svld1_f32(mask3, a + i + lanes3); - const auto bi0 = svld1_f32(mask0, b + i); - const auto bi1 = svld1_f32(mask1, b + i + lanes); - const auto bi2 = svld1_f32(mask2, b + i + lanes2); - const auto bi3 = svld1_f32(mask3, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); - const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); - svst1_f32(mask0, c + i, ci0); - svst1_f32(mask1, c + i + lanes, ci1); - svst1_f32(mask2, c + i + lanes2, ci2); - svst1_f32(mask3, c + i + lanes3, ci3); -} - -#elif defined(__aarch64__) - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - const size_t n_simd = n - (n & 3); - const float32x4_t bfv = vdupq_n_f32(bf); - size_t i; - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - } - for (; i < n; ++i) - c[i] = a[i] + bf * b[i]; -} - -#else - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - fvec_madd_ref(n, a, bf, b, c); -} - -#endif - -static inline int fvec_madd_and_argmin_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float vmin = 1e20; - int imin = -1; - - for (size_t i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = i; - } - } - return imin; -} - -#ifdef __SSE3__ - -static inline int fvec_madd_and_argmin_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128 vmin4 = _mm_set_ps1(1e20); - __m128i imin4 = _mm_set1_epi32(-1); - __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); - __m128i inc4 = _mm_set1_epi32(4); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - *c4 = vc4; - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
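// The statement just below emulates a per-lane select without SSE4.1's blendv:
// (mask AND new) OR (NOT mask AND old) takes the new lane where the mask is
// all-ones and keeps the old lane where it is zero. Generic sketch of the idiom
// (illustrative helper name, assumes SSE2 / <emmintrin.h>; mask lanes must be
// all-ones or all-zeros, e.g. produced by _mm_cmpgt_ps):
static inline __m128i select_epi32_sse2(__m128i mask, __m128i if_set, __m128i if_clear) {
    return _mm_or_si128(_mm_and_si128(mask, if_set), _mm_andnot_si128(mask, if_clear));
}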
- - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - b4++; - a4++; - c4++; - idx4 = _mm_add_epi32(idx4, inc4); - } - - // 4 values -> 2 - { - idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - } - // 2 values -> 1 - { - idx4 = _mm_shuffle_epi32(imin4, 1); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - // vmin4 = _mm_min_ps (vmin4, vc4); - } - return _mm_cvtsi128_si32(imin4); -} - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) { - return fvec_madd_and_argmin_sse(n, a, bf, b, c); - } else { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); - } -} - -#elif defined(__aarch64__) - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float32x4_t vminv = vdupq_n_f32(1e20); - uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); - size_t i; - { - const size_t n_simd = n - (n & 3); - const uint32_t iota[] = {0, 1, 2, 3}; - uint32x4_t iv = vld1q_u32(iota); - const uint32x4_t incv = vdupq_n_u32(4); - const float32x4_t bfv = vdupq_n_f32(bf); - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - const uint32x4_t less_than = vcltq_f32(ci, vminv); - vminv = vminq_f32(ci, vminv); - iminv = vorrq_u32( - vandq_u32(less_than, iv), - vandq_u32(vmvnq_u32(less_than), iminv)); - iv = vaddq_u32(iv, incv); - } - } - float vmin = vminvq_f32(vminv); - uint32_t imin; - { - const float32x4_t vminy = vdupq_n_f32(vmin); - const uint32x4_t equals = vceqq_f32(vminv, vminy); - imin = vminvq_u32(vorrq_u32( - vandq_u32(equals, iminv), - vandq_u32( - vmvnq_u32(equals), - vdupq_n_u32(std::numeric_limits::max())))); - } - for (; i < n; ++i) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = static_cast(i); - } - } - return static_cast(imin); -} - -#else +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { + DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c); +} int fvec_madd_and_argmin( size_t n, @@ -3603,11 +304,9 @@ int fvec_madd_and_argmin( float bf, const float* b, float* c) { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); + DISPATCH_SIMDLevel(fvec_madd_and_argmin, n, a, bf, b, c); } -#endif - /*************************************************************************** * PQ tables computations ***************************************************************************/ diff --git a/faiss/utils/extra_distances-inl.h b/faiss/utils/extra_distances-inl.h index 6a374ed518..066ba55590 100644 --- a/faiss/utils/extra_distances-inl.h +++ b/faiss/utils/extra_distances-inl.h @@ -59,13 +59,6 @@ inline float VectorDistance::operator()( const float* x, const float* y) const { return fvec_Linf(x, y, d); - /* - float vmax = 0; - for (size_t i = 0; i < d; i++) { - float diff = fabs (x[i] - y[i]); - if (diff > vmax) vmax = diff; - } - return vmax;*/ } template <> diff --git a/faiss/utils/simd_impl/distances_aarch64.cpp 
b/faiss/utils/simd_impl/distances_aarch64.cpp new file mode 100644 index 0000000000..33ad9bbc4f --- /dev/null +++ b/faiss/utils/simd_impl/distances_aarch64.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_NEON +#include + +namespace faiss { + +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + const size_t n_simd = n - (n & 3); + const float32x4_t bfv = vdupq_n_f32(bf); + size_t i; + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + } + for (; i < n; ++i) + c[i] = a[i] + bf * b[i]; +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny(dis, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_nearest(distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed_ref( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + float32x4_t vminv = vdupq_n_f32(1e20); + uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); + size_t i; + { + const size_t n_simd = n - (n & 3); + const uint32_t iota[] = {0, 1, 2, 3}; + uint32x4_t iv = vld1q_u32(iota); + const uint32x4_t incv = vdupq_n_u32(4); + const float32x4_t bfv = vdupq_n_f32(bf); + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + const uint32x4_t less_than = vcltq_f32(ci, vminv); + vminv = vminq_f32(ci, vminv); + iminv = vorrq_u32( + vandq_u32(less_than, iv), + vandq_u32(vmvnq_u32(less_than), iminv)); + iv = vaddq_u32(iv, incv); + } + } + float vmin = vminvq_f32(vminv); + uint32_t imin; + { + const float32x4_t vminy = vdupq_n_f32(vmin); + const uint32x4_t equals = vceqq_f32(vminv, vminy); + imin = vminvq_u32(vorrq_u32( + vandq_u32(equals, iminv), + vandq_u32( + vmvnq_u32(equals), + vdupq_n_u32(std::numeric_limits::max())))); + } + for (; i < n; ++i) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = static_cast(i); + } + } + return static_cast(imin); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_arm_sve.cpp b/faiss/utils/simd_impl/distances_arm_sve.cpp new file mode 100644 index 0000000000..3bd4227da0 --- /dev/null +++ b/faiss/utils/simd_impl/distances_arm_sve.cpp @@ -0,0 +1,496 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_SVE +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t lanes = static_cast(svcntw()); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + size_t i = 0; + for (; i + lanes4 < n; i += lanes4) { + const auto mask = svptrue_b32(); + const auto ai0 = svld1_f32(mask, a + i); + const auto ai1 = svld1_f32(mask, a + i + lanes); + const auto ai2 = svld1_f32(mask, a + i + lanes2); + const auto ai3 = svld1_f32(mask, a + i + lanes3); + const auto bi0 = svld1_f32(mask, b + i); + const auto bi1 = svld1_f32(mask, b + i + lanes); + const auto bi2 = svld1_f32(mask, b + i + lanes2); + const auto bi3 = svld1_f32(mask, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); + svst1_f32(mask, c + i, ci0); + svst1_f32(mask, c + i + lanes, ci1); + svst1_f32(mask, c + i + lanes2, ci2); + svst1_f32(mask, c + i + lanes3, ci3); + } + const auto mask0 = svwhilelt_b32_u64(i, n); + const auto mask1 = svwhilelt_b32_u64(i + lanes, n); + const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); + const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); + const auto ai0 = svld1_f32(mask0, a + i); + const auto ai1 = svld1_f32(mask1, a + i + lanes); + const auto ai2 = svld1_f32(mask2, a + i + lanes2); + const auto ai3 = svld1_f32(mask3, a + i + lanes3); + const auto bi0 = svld1_f32(mask0, b + i); + const auto bi1 = svld1_f32(mask1, b + i + lanes); + const auto bi2 = svld1_f32(mask2, b + i + lanes2); + const auto bi3 = svld1_f32(mask3, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); + svst1_f32(mask0, c + i, ci0); + svst1_f32(mask1, c + i + lanes, ci1); + svst1_f32(mask2, c + i + lanes2, ci2); + svst1_f32(mask3, c + i + lanes3, ci3); +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +struct ElementOpIP { + static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { + return svmul_f32_x(pg, x, y); + } + static svfloat32_t merge( + svbool_t pg, + svfloat32_t z, + svfloat32_t x, + svfloat32_t y) { + return svmla_f32_x(pg, z, x, y); + } +}; + +template +void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + size_t i = 0; + for (; i + lanes4 < ny; i += lanes4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + svst1_f32(pg, dis, y0); + svst1_f32(pg, dis + lanes, y1); + 
svst1_f32(pg, dis + lanes2, y2); + svst1_f32(pg, dis + lanes3, y3); + y += lanes4; + dis += lanes4; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); + const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); + svfloat32_t y0 = svld1_f32(pg0, y); + svfloat32_t y1 = svld1_f32(pg1, y + lanes); + svfloat32_t y2 = svld1_f32(pg2, y + lanes2); + svfloat32_t y3 = svld1_f32(pg3, y + lanes3); + y0 = ElementOp::op(pg0, x0, y0); + y1 = ElementOp::op(pg1, x0, y1); + y2 = ElementOp::op(pg2, x0, y2); + y3 = ElementOp::op(pg3, x0, y3); + svst1_f32(pg0, dis, y0); + svst1_f32(pg1, dis + lanes, y1); + svst1_f32(pg2, dis + lanes2, y2); + svst1_f32(pg3, dis + lanes3, y3); +} + +template +void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + size_t i = 0; + for (; i + lanes2 < ny; i += lanes2) { + const svfloat32x2_t y0 = svld2_f32(pg, y); + const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + svst1_f32(pg, dis, y00); + svst1_f32(pg, dis + lanes, y10); + y += lanes4; + dis += lanes2; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svfloat32x2_t y0 = svld2_f32(pg0, y); + const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg0, x0, y00); + y10 = ElementOp::op(pg1, x0, y10); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y10 = ElementOp::merge(pg1, y10, x1, y11); + svst1_f32(pg0, dis, y00); + svst1_f32(pg1, dis + lanes, y10); +} + +template +void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t y0 = svld4_f32(pg, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg, x0, y00); + y02 = ElementOp::op(pg, x2, y02); + y00 = ElementOp::merge(pg, y00, x1, y01); + y02 = ElementOp::merge(pg, y02, x3, y03); + y00 = svadd_f32_x(pg, y00, y02); + svst1_f32(pg, dis, y00); + y += lanes4; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svfloat32x4_t y0 = svld4_f32(pg0, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg0, x0, y00); + y02 = ElementOp::op(pg0, x2, y02); + y00 = 
ElementOp::merge(pg0, y00, x1, y01); + y02 = ElementOp::merge(pg0, y02, x3, y03); + y00 = svadd_f32_x(pg0, y00, y02); + svst1_f32(pg0, dis, y00); +} + +template +void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const size_t lanes8 = lanes * 8; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + const svfloat32_t x4 = svdup_n_f32(x[4]); + const svfloat32_t x5 = svdup_n_f32(x[5]); + const svfloat32_t x6 = svdup_n_f32(x[6]); + const svfloat32_t x7 = svdup_n_f32(x[7]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t ya = svld4_f32(pg, y); + const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y4 = ElementOp::op(pg, x4, y4); + y6 = ElementOp::op(pg, x6, y6); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y4 = ElementOp::merge(pg, y4, x5, y5); + y6 = ElementOp::merge(pg, y6, x7, y7); + y0 = svadd_f32_x(pg, y0, y2); + y4 = svadd_f32_x(pg, y4, y6); + y0 = svadd_f32_x(pg, y0, y4); + svst1_f32(pg, dis, y0); + y += lanes8; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); + const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); + const svfloat32x4_t ya = svld4_f32(pga, y); + const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg0, x0, y0); + y2 = ElementOp::op(pg0, x2, y2); + y4 = ElementOp::op(pg0, x4, y4); + y6 = ElementOp::op(pg0, x6, y6); + y0 = ElementOp::merge(pg0, y0, x1, y1); + y2 = ElementOp::merge(pg0, y2, x3, y3); + y4 = ElementOp::merge(pg0, y4, x5, y5); + y6 = ElementOp::merge(pg0, y6, x7, y7); + y0 = svadd_f32_x(pg0, y0, y2); + y4 = svadd_f32_x(pg0, y4, y6); + y0 = svadd_f32_x(pg0, y0, y4); + svst1_f32(pg0, dis, y0); + y += lanes8; + dis += lanes; +} + +template +void fvec_op_ny_sve_lanes1( + float* dis, + const float* x, + const float* y, + size_t ny) { + 
const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + size_t i = 0; + for (; i + 3 < ny; i += 4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + dis[i] = svaddv_f32(pg, y0); + dis[i + 1] = svaddv_f32(pg, y1); + dis[i + 2] = svaddv_f32(pg, y2); + dis[i + 3] = svaddv_f32(pg, y3); + } + for (; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + y += lanes; + y0 = ElementOp::op(pg, x0, y0); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + size_t i = 0; + for (; i + 1 < ny; i += 2) { + svfloat32_t y00 = svld1_f32(pg, y); + const svfloat32_t y01 = svld1_f32(pg, y + lanes); + svfloat32_t y10 = svld1_f32(pg, y + lanes2); + const svfloat32_t y11 = svld1_f32(pg, y + lanes3); + y += lanes4; + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + dis[i] = svaddv_f32(pg, y00); + dis[i + 1] = svaddv_f32(pg, y10); + } + if (i < ny) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes3( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + y += lanes3; + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + y0 = ElementOp::merge(pg, y0, x2, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + const svfloat32_t x3 = svld1_f32(pg, x + lanes3); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + const svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y0 = 
svadd_f32_x(pg, y0, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + const size_t lanes = svcntw(); + switch (d) { + case 1: + fvec_op_ny_sve_d1(dis, x, y, ny); + break; + case 2: + fvec_op_ny_sve_d2(dis, x, y, ny); + break; + case 4: + fvec_op_ny_sve_d4(dis, x, y, ny); + break; + case 8: + fvec_op_ny_sve_d8(dis, x, y, ny); + break; + default: + if (d == lanes) + fvec_op_ny_sve_lanes1(dis, x, y, ny); + else if (d == lanes * 2) + fvec_op_ny_sve_lanes2(dis, x, y, ny); + else if (d == lanes * 3) + fvec_op_ny_sve_lanes3(dis, x, y, ny); + else if (d == lanes * 4) + fvec_op_ny_sve_lanes4(dis, x, y, ny); + else + fvec_inner_products_ny(dis, x, y, d, ny); + break; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_autovec-inl.h b/faiss/utils/simd_impl/distances_autovec-inl.h new file mode 100644 index 0000000000..62d13eb38e --- /dev/null +++ b/faiss/utils/simd_impl/distances_autovec-inl.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_norm_L2sqr(const float* x, size_t d) { + // the double in the _ref is suspected to be a typo. Some of the manual + // implementations this replaces used float. 
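    // Shared-kernel note (inferred from this diff): each per-level translation unit
    // (distances_aarch64.cpp, distances_avx2.cpp, distances_avx512.cpp, ...) defines
    // AUTOVEC_LEVEL before including this header, so the plain loops below are
    // specialized once per SIMD level and, presumably, left to the compiler to
    // auto-vectorize under that TU's target flags; the FAISS_PRAGMA_IMPRECISE_*
    // markers appear to relax floating-point ordering so reductions like this one
    // can be vectorized.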
+ float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * x[i]; + } + + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L2sqr(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_inner_product( + const float* x, + const float* y, + size_t d) { + float res = 0.F; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * y[i]; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L1(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += fabs(tmp); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_Linf(const float* x, const float* y, size_t d) { + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res = fmax(res, fabs(x[i] - y[i])); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_inner_product_batch_4( + const float* __restrict x, + const float* __restrict y0, + const float* __restrict y1, + const float* __restrict y2, + const float* __restrict y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * y0[i]; + d1 += x[i] * y1[i]; + d2 += x[i] * y2[i]; + d3 += x[i] * y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - y0[i]; + const float q1 = x[i] - y1[i]; + const float q2 = x[i] - y2[i]; + const float q3 = x[i] - y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx.cpp b/faiss/utils/simd_impl/distances_avx.cpp new file mode 100644 index 0000000000..c29e64c91f --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#ifdef __AVX__ + +float fvec_L1(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find sum of absolute value of distances (manhattan distance) + msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_hadd_ps(msum2, msum2); + msum2 = _mm_hadd_ps(msum2, msum2); + return _mm_cvtss_f32(msum2); +} + +float fvec_Linf(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find max of absolute value of distances (chebyshev distance) + msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); + msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); + return _mm_cvtss_f32(msum2); +} + +#endif diff --git a/faiss/utils/simd_impl/distances_avx2.cpp b/faiss/utils/simd_impl/distances_avx2.cpp new file mode 100644 index 0000000000..acfcbabe17 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx2.cpp @@ -0,0 +1,1178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX2 +#include + +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + // + const size_t n8 = n / 8; + const size_t n_for_masking = n % 8; + + const __m256 bfmm = _mm256_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n8 * 8; idx += 8) { + const __m256 ax = _mm256_loadu_ps(a + idx); + const __m256 bx = _mm256_loadu_ps(b + idx); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + __m256i mask; + switch (n_for_masking) { + case 1: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); + break; + case 2: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); + break; + case 3: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); + break; + case 4: + mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); + break; + case 5: + mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); + break; + case 6: + mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); + break; + case 7: + mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); + break; + } + + const __m256 ax = _mm256_maskload_ps(a + idx, mask); + const __m256 bx = _mm256_maskload_ps(b + idx, mask); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_maskstore_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + + // compute dot products + // this is x^2 - 2x[0]*y[0] + __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fnmadd_ps(m[j], vj, dp); + } + + // we've got x^2 - (2x, y) at this point + + // y^2 - (2x, y) + x^2 + __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); + + _mm256_storeu_ps(distances + i, distances_v); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
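            // Clarification: unlike the *_nearest_y_transposed variant further below
            // in this file, this kernel stores full squared distances, so the constant
            // x^2 term is added back:
            //   ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2
            // (dropping x^2 is only valid when ranking candidates for an argmin).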
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX2ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX2ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX2 +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
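            // What the transpose buys (derived from the arithmetic below): after
            // transpose_8x2, v0 holds component 0 and v1 holds component 1 of eight
            // consecutive 2-D vectors, so per lane j the sub/fmadd chain computes
            //   dis[i + j] = (x[0] - y[2j])^2 + (x[1] - y[2j + 1])^2
            // for eight vectors at once, with purely vertical ops and a single store.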
+ + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
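            // Same register-transpose trick as in the D2 kernels: each vj below ends
            // up holding component j of eight consecutive 4-D vectors, so eight
            // squared distances fall out of vertical sub/fmadd ops with no per-vector
            // horizontal reduction; contrast the leftover path at the end of this
            // function, which needs a horizontal_sum() per remaining vector.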
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + distances = _mm256_fmadd_ps(m4, v4, distances); + distances = _mm256_fmadd_ps(m5, v5, distances); + distances = _mm256_fmadd_ps(m6, v6, distances); + distances = _mm256_fmadd_ps(m7, v7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D2-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. 
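            // blendv semantics, for reference: _mm256_blendv_ps(a, b, m) returns b in
            // lanes where the sign bit of m is set and a otherwise. Here `comparison`
            // is all-ones exactly in the lanes where the previous minimum is still
            // smaller, so those lanes keep min_indices while the others take
            // current_indices, matching the _mm256_min_ps update of min_distances.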
+ min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 2 DIM each). + y += 16; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers. + // the following code is not optimal, but it is rarely invoked. + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D4-vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. 
+ current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 4 DIM each). + y += 32; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D8-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. 
+ __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 8 DIM each). + y += 64; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + // compute dot products + __m256 dp = _mm256_mul_ps(m[0], v0); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fmadd_ps(m[j], vj, dp); + } + + // compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m256 distances = + _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. 
+ const __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = + _mm256_blendv_ps(distances, min_distances, comparison); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. + const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { +// optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_nearest_y_transposed_D( \ + distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx512.cpp b/faiss/utils/simd_impl/distances_avx512.cpp new file mode 100644 index 0000000000..06d5b399f4 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx512.cpp @@ -0,0 +1,1092 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX512 +#include +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] + } + + __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); + + for (; i < ny16 * 16; i += 16) { + // Load vectors for 16 dimensions + __m512 v[DIM]; + for (size_t j = 0; j < DIM; j++) { + v[j] = _mm512_loadu_ps(y + j * d_offset); + } + + // Compute dot products + __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); + for (size_t j = 1; j < DIM; j++) { + dp = _mm512_fnmadd_ps(m[j], v[j], dp); + } + + // Compute y^2 - (2 * x, y) + x^2 + __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); + + _mm512_storeu_ps(distances + i, distances_v); + + // Scroll y and y_sqlen forward + y += 16; + y_sqlen += 16; + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX512ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m512 op(__m512 x, __m512 y) { + return _mm512_mul_ps(x, y); + } + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX512ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + static __m512 op(__m512 x, __m512 y) { + __m512 tmp = _mm512_sub_ps(x, y); + return _mm512_mul_ps(tmp, tmp); + } + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX512 +inline float horizontal_sum(const __m512 v) { + // performs better than adding the high and low parts + return _mm512_reduce_add_ps(v); +} + +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute distances (dot product) + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + distances = _mm512_fmadd_ps(m4, v4, distances); + distances = _mm512_fmadd_ps(m5, v5, distances); + distances = _mm512_fmadd_ps(m6, v6, distances); + distances = _mm512_fmadd_ps(m7, v7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. 
+ // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
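+    // It keeps 16 candidate minima in parallel, one per AVX-512 lane, and
+    // reduces them to a single scalar minimum after the vectorized loop;
+    // the D4 and D8 variants below follow the same pattern.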
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 32; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 64; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 128; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return 
fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +// TODO: Following functions are not used in the current codebase. Check AVX2 , +// respective implementation has been used +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // This implementation does not use distances_tmp_buffer. + + // Current index being processed + size_t i = 0; + + // Min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // Process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // Track min distance and the closest vector independently + // for each of 16 AVX-512 components. + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); + } + + for (; i < ny16 * 16; i += 16) { + // Compute dot products + const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); + __m512 dp = _mm512_mul_ps(m[0], v0); + for (size_t j = 1; j < DIM; j++) { + const __m512 vj = _mm512_loadu_ps(y + j * d_offset); + dp = _mm512_fmadd_ps(m[j], vj, dp); + } + + // Compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m512 distances = + _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); + + // Compare the new distances to the min distances + __mmask16 comparison = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + + // Update min distances and indices with closest vectors if needed + min_distances = + _mm512_mask_blend_ps(comparison, distances, min_distances); + min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( + comparison, + _mm512_castsi512_ps(current_indices), + _mm512_castsi512_ps(min_indices))); + + // Update current indices values. Basically, +16 to each of the 16 + // AVX-512 components. + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + // Scroll y and y_sqlen forward. + y += 16; + y_sqlen += 16; + } + + // Dump values and find the minimum distance / minimum index + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances); + _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
+ const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_sse-inl.h b/faiss/utils/simd_impl/distances_sse-inl.h new file mode 100644 index 0000000000..a5151750cb --- /dev/null +++ b/faiss/utils/simd_impl/distances_sse-inl.h @@ -0,0 +1,385 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace faiss { + +[[maybe_unused]] static inline void fvec_madd_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + b4++; + a4++; + c4++; + } +} + +/// helper function +inline float horizontal_sum(const __m128 v) { + // say, v is [x0, x1, x2, x3] + + // v0 is [x2, x3, ..., ...] + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + // v1 is [x0 + x2, x1 + x3, ..., ...] + const __m128 v1 = _mm_add_ps(v, v0); + // v2 is [x1 + x3, ..., .... ,...] + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] + const __m128 v3 = _mm_add_ps(v1, v2); + // return v3[0] + return _mm_cvtss_f32(v3); +} + +/// Function that does a component-wise operation between x and y +/// to compute inner products +struct ElementOpIP { + static float op(float x, float y) { + return x * y; + } + + static __m128 op(__m128 x, __m128 y) { + return _mm_mul_ps(x, y); + } +}; + +/// Function that does a component-wise operation between x and y +/// to compute L2 distances. 
ElementOp can then be used in the fvec_op_ny +/// functions below +struct ElementOpL2 { + static float op(float x, float y) { + float tmp = x - y; + return tmp * tmp; + } + + static __m128 op(__m128 x, __m128 y) { + __m128 tmp = _mm_sub_ps(x, y); + return _mm_mul_ps(tmp, tmp); + } +}; + +template +void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { + float x0s = x[0]; + __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); + + size_t i; + for (i = 0; i + 3 < ny; i += 4) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = _mm_cvtss_f32(accu); + __m128 tmp = _mm_shuffle_ps(accu, accu, 1); + dis[i + 1] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 2); + dis[i + 2] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 3); + dis[i + 3] = _mm_cvtss_f32(tmp); + } + while (i < ny) { // handle non-multiple-of-4 case + dis[i++] = ElementOp::op(x0s, *y++); + } +} + +template +void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); + + size_t i; + for (i = 0; i + 1 < ny; i += 2) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + accu = _mm_shuffle_ps(accu, accu, 3); + dis[i + 1] = _mm_cvtss_f32(accu); + } + if (i < ny) { // handle odd case + dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); + } +} + +template +void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); + y += 4; + accu = _mm_hadd_ps(accu, accu); + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + } +} + +template +void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + __m128 x2 = _mm_loadu_ps(x + 8); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_inner_products_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_inner_products_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +void fvec_L2sqr_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + // optimized for a few special cases + +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_L2sqr_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t 
fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D2, + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D4, + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D8); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t)) { + switch (d) { + case 2: + return fvec_L2sqr_ny_nearest_D2_func( + distances_tmp_buffer, x, y, ny); + case 4: + return fvec_L2sqr_ny_nearest_D4_func( + distances_tmp_buffer, x, y, ny); + case 8: + return fvec_L2sqr_ny_nearest_D8_func( + distances_tmp_buffer, x, y, ny); + } + + return fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +template +inline size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + +static inline int fvec_madd_and_argmin_sse_ref( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128 vmin4 = _mm_set_ps1(1e20); + __m128i imin4 = _mm_set1_epi32(-1); + __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); + __m128i inc4 = _mm_set1_epi32(4); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + *c4 = vc4; + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
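+        // emulated blend: take idx4 where mask is set, keep imin4 elsewhere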
+ + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + b4++; + a4++; + c4++; + idx4 = _mm_add_epi32(idx4, inc4); + } + + // 4 values -> 2 + { + idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + } + // 2 values -> 1 + { + idx4 = _mm_shuffle_epi32(imin4, 1); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + // vmin4 = _mm_min_ps (vmin4, vc4); + } + return _mm_cvtsi128_si32(imin4); +} + +static inline int fvec_madd_and_argmin_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) + return fvec_madd_and_argmin_sse_ref(n, a, bf, b, c); + + return fvec_madd_and_argmin(n, a, bf, b, c); +} + +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read(int d, const float* x) { + assert(0 <= d && d < 4); + ALIGNED(16) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + [[fallthrough]]; + case 2: + buf[1] = x[1]; + [[fallthrough]]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp index 887225ee3b..3f1769b289 100644 --- a/faiss/utils/simd_levels.cpp +++ b/faiss/utils/simd_levels.cpp @@ -125,10 +125,9 @@ SIMDLevel SIMDConfig::auto_detect_simd_level() { #if defined(__aarch64__) && defined(__ARM_NEON) && \ defined(COMPILE_SIMD_ARM_NEON) - // ARM NEON is standard on aarch64, so we can assume it's available + // ARM NEON is standard on aarch64 supported_simd_levels().insert(SIMDLevel::ARM_NEON); level = SIMDLevel::ARM_NEON; - // TODO: Add ARM SVE detection when needed // For now, we default to ARM_NEON as it's universally supported on aarch64 #endif diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h index ad3d0b289d..95b2decc0b 100644 --- a/faiss/utils/simd_levels.h +++ b/faiss/utils/simd_levels.h @@ -61,7 +61,7 @@ struct SIMDConfig { #ifdef COMPILE_SIMD_AVX512 #define DISPATCH_SIMDLevel_AVX512(f, ...) \ - case SIMDLevel::AVX512F: \ + case SIMDLevel::AVX512: \ return f(__VA_ARGS__) #else #define DISPATCH_SIMDLevel_AVX512(f, ...) diff --git a/tests/test_distances_simd.cpp b/tests/test_distances_simd.cpp index 539fe2a419..dda33c3e72 100644 --- a/tests/test_distances_simd.cpp +++ b/tests/test_distances_simd.cpp @@ -39,104 +39,352 @@ void fvec_L2sqr_ny_ref( } } -// test templated versions of fvec_L2sqr_ny -TEST(TestFvecL2sqrNy, D2) { - // we're using int values in order to get 100% accurate - // results with floats. 
- std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +void remove_simd_level_if_exists( + std::unordered_set& levels, + faiss::SIMDLevel level) { + std::erase_if( + levels, [level](faiss::SIMDLevel elem) { return elem == level; }); +} - for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); +class DistancesSIMDTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + original_simd_level = faiss::SIMDConfig::get_level(); + std::iota(dims.begin(), dims.end(), 1); + + ntests = 4; + + simd_level = GetParam(); + faiss::SIMDConfig::set_level(simd_level); + + EXPECT_EQ(faiss::SIMDConfig::get_level(), simd_level); + + rng = std::default_random_engine(123); + uniform = std::uniform_int_distribution(0, 32); + } + + void TearDown() override { + faiss::SIMDConfig::set_level(original_simd_level); + } + + std::tuple, std::vector>> + SetupTestData(int dims, int ny) { + std::vector x(dims); + std::vector> y(ny, std::vector(dims)); + + for (size_t i = 0; i < dims; i++) { + x[i] = uniform(rng); + for (size_t j = 0; j < ny; j++) { + y[j][i] = uniform(rng); + } + } + return std::make_tuple(x, y); + } + + std::vector flatten_2d_vector( + const std::vector>& v) { + std::vector flat_v; + for (const auto& vec : v) { + flat_v.insert(flat_v.end(), vec.begin(), vec.end()); + } + return flat_v; + } + + faiss::SIMDLevel simd_level = faiss::SIMDLevel::NONE; + faiss::SIMDLevel original_simd_level = faiss::SIMDLevel::NONE; + std::default_random_engine rng; + std::uniform_int_distribution uniform; + + std::vector dims = {128}; + int ntests = 1; +}; + +TEST_P(DistancesSIMDTest, LinfDistance_chebyshev_distance) { + for (int i = 0; i < ntests; ++i) { // repeat tests + for (const auto dim : dims) { // test different dimensions + int ny = 1; + auto [x, y] = SetupTestData(dim, ny); + for (int k = 0; k < ny; ++k) { // test different vectors + float distance = faiss::fvec_Linf(x.data(), y[k].data(), dim); + float ref_distance = 0; + + for (int j = 0; j < dim; ++j) { + ref_distance = + std::max(ref_distance, std::abs(x[j] - y[k][j])); + } + ASSERT_EQ(distance, ref_distance); + } } + } +} - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); +TEST_P(DistancesSIMDTest, inner_product_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + true_distances[j] += x[k] * y[j][k]; } + } - std::vector distances(nrows, 0); - faiss::fvec_L2sqr_ny( - distances.data(), x.data(), y.data(), dim, nrows); + std::vector actual_distances(ny, 0.F); + faiss::fvec_inner_product_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_product_batch4 results for test = " + << i; + } +} + +TEST_P(DistancesSIMDTest, fvec_L2sqr) { + for (int i = 0; i < ntests; ++i) { + int ny = 1; + for (const auto dim : dims) { + auto [x, y] = SetupTestData(dim, ny); + float true_distance = 0.F; + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[0][k]; + true_distance += tmp * tmp; + } - std::vector distances_ref(nrows, 0); - fvec_L2sqr_ny_ref( - distances_ref.data(), x.data(), y.data(), 
dim, nrows); + float actual_distance = + faiss::fvec_L2sqr(x.data(), y[0].data(), dim); - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distance, true_distance) + << "Mismatching fvec_L2sqr results for test = " << i; } } } -// fvec_inner_products_ny -TEST(TestFvecInnerProductsNy, D2) { - // we're using int values in order to get 100% accurate - // results with floats. - std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +TEST_P(DistancesSIMDTest, L2sqr_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[j][k]; + true_distances[j] += tmp * tmp; + } + } + + std::vector actual_distances(ny, 0.F); + faiss::fvec_L2sqr_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_batch_4 results for test = " << i; + } +} +TEST_P(DistancesSIMDTest, fvec_L2sqr_ny) { for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); - } + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + + std::vector actual_distances(ny, 0.F); - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); + std::vector flat_y; + for (auto y_ : y) { + flat_y.insert(flat_y.end(), y_.begin(), y_.end()); } - std::vector distances(nrows, 0); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_distances[i] += tmp * tmp; + } + } + + faiss::fvec_L2sqr_ny( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_ny results for dim = " << dim + << ", ny = " << ny; + } + } +} + +TEST_P(DistancesSIMDTest, fvec_inner_products_ny) { + for (const auto dim : {2, 4, 8, 12}) { + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = flatten_2d_vector(y); + + std::vector actual_distances(ny, 0.F); faiss::fvec_inner_products_ny( - distances.data(), x.data(), y.data(), dim, nrows); + actual_distances.data(), x.data(), flat_y.data(), dim, ny); - std::vector distances_ref(nrows, 0); - fvec_inner_products_ny_ref( - distances_ref.data(), x.data(), y.data(), dim, nrows); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + true_distances[i] += x[k] * y[i][k]; + } + } - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_products_ny results for dim = " + << dim << ", ny = " << ny; } } } -TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy +TEST_P(DistancesSIMDTest, L2SqrNYNearest) { std::default_random_engine rng(123); std::uniform_int_distribution uniform(0, 32); + int dim = 128; + int ny = 11; + + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = flatten_2d_vector(y); + + std::vector 
true_tmp_buffer_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_tmp_buffer_distances[i] += tmp * tmp; + } + } + + size_t true_nearest_idx = 0; + float min_dis = HUGE_VALF; + + for (size_t i = 0; i < ny; i++) { + if (true_tmp_buffer_distances[i] < min_dis) { + min_dis = true_tmp_buffer_distances[i]; + true_nearest_idx = i; + } + } + + std::vector actual_distances(ny); + auto actual_nearest_index = faiss::fvec_L2sqr_ny_nearest( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + EXPECT_EQ(actual_nearest_index, true_nearest_idx); +} + +TEST_P(DistancesSIMDTest, multiple_add) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + const float bf = uniform(rng); + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + bf * y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_madd( + x.size(), x.data(), bf, y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_madd results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, manhattan_distance) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + float true_distance = 0; + for (size_t i = 0; i < x.size(); i++) { + true_distance += std::abs(x[i] - y[0][i]); + } + + auto actual_distances = faiss::fvec_L1(x.data(), y[0].data(), x.size()); + + ASSERT_EQ(actual_distances, true_distance) + << "Mismatching fvec_Linf results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, add_value) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + const float b = uniform(rng); // value to add + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + b; + } + + std::vector actual_distances(dim); + faiss::fvec_add(x.size(), x.data(), b, actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-value fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, add_array) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_add( + x.size(), x.data(), y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-array fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, distances_L2_squared_y_transposed) { // modulo 8 results - 16 is to repeat the loop in the function int ny = 11; // this value will hit all the codepaths for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); + auto [x, y] = SetupTestData(d, ny); float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (size_t i = 0; i < d; ++i) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); - for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t i = 0; i < ny; ++i) { + for (size_t j = 0; j < d; ++j) { + 
y_sqlens[i] += flat_y[j] * flat_y[j]; } } // perform function std::vector true_distances(ny, 0); - for (size_t i = 0; i < ny; i++) { + for (size_t i = 0; i < ny; ++i) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < d; ++j) { + dp += x[j] * flat_y[i + j * ny]; } true_distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } @@ -145,7 +393,7 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { faiss::fvec_L2sqr_ny_transposed( distances.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), d, ny, // no need for special offset to test all lines of code @@ -156,39 +404,34 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { } } -TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - +TEST_P(DistancesSIMDTest, nearest_L2_squared_y_transposed) { // modulo 8 results - 16 is to repeat the loop in the function int ny = 11; // this value will hit all the codepaths - for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); - float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (const auto dim : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { + auto [x, y] = SetupTestData(dim, ny); + float x_sqlen = 0.F; + for (size_t i = 0; i < dim; i++) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); + for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t j = 0; j < dim; j++) { + y_sqlens[i] += y[i][j] * y[i][j]; } } - // get distances std::vector distances(ny, 0); for (size_t i = 0; i < ny; i++) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < dim; j++) { + dp += x[j] * flat_y[i + j * ny]; } distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } + // find nearest size_t true_nearest_idx = 0; float min_dis = HUGE_VALF; @@ -200,135 +443,42 @@ TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { } std::vector buffer(ny); - size_t nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( + size_t actual_nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( buffer.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), - d, + dim, ny, // no need for special offset to test all lines of code ny); - ASSERT_EQ(nearest_idx, true_nearest_idx) + ASSERT_EQ(actual_nearest_idx, true_nearest_idx) << "Mismatching fvec_L2sqr_ny_nearest_y_transposed results for d = " - << d; + << dim; } } -TEST(TestFvecL1, manhattan_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); +std::vector GetSupportedSIMDLevels() { + std::vector supported_levels = {faiss::SIMDLevel::NONE}; - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance += std::abs(x[i] - y[i]); + for (int level = static_cast(faiss::SIMDLevel::NONE) + 1; + level < static_cast(faiss::SIMDLevel::COUNT); + level++) { + faiss::SIMDLevel simd_level = static_cast(level); + if (faiss::SIMDConfig::is_simd_level_available(simd_level)) { + 
supported_levels.push_back(simd_level); } - - auto distance = faiss::fvec_L1(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_Linf results for nrows = " << nrows; } -} -TEST(TestFvecLinf, chebyshev_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); + EXPECT_TRUE(supported_levels.size() > 0); - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance = std::max(true_distance, std::abs(x[i] - y[i])); - } - - auto distance = faiss::fvec_Linf(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_Linf results for nrows = " << nrows; - } + return std::vector( + supported_levels.begin(), supported_levels.end()); } -TEST(TestFvecMadd, multiple_add) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector a(nrows); - std::vector b(nrows); - const float bf = uniform(rng); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + bf * b[i]; - } - - std::vector distances(nrows); - faiss::fvec_madd(a.size(), a.data(), bf, b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching fvec_madd results for nrows = " << nrows; - } +::testing::internal::ParamGenerator SupportedSIMDLevels() { + std::vector levels = GetSupportedSIMDLevels(); + return ::testing::ValuesIn(levels); } -TEST(TestFvecAdd, add_array) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - std::vector b(nrows); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + b[i]; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-array fvec_add results for nrows = " - << nrows; - } -} - -TEST(TestFvecAdd, add_value) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - const float b = uniform(rng); // value to add - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - true_distances[i] = a[i] + b; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b, distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-value fvec_add results for nrows = " - << nrows; - } -} +INSTANTIATE_TEST_SUITE_P(SIMDLevels, DistancesSIMDTest, SupportedSIMDLevels()); diff --git a/tests/test_simd_levels.cpp b/tests/test_simd_levels.cpp index 4dac2e9877..64da6e77b9 100644 --- a/tests/test_simd_levels.cpp +++ b/tests/test_simd_levels.cpp @@ -6,8 +6,6 @@ */ #include 
-#include -#include #ifdef __x86_64__ #include @@ -15,25 +13,9 @@ #include -static jmp_buf jmpbuf; -static void sigill_handler(int sig) { - longjmp(jmpbuf, 1); -} - -bool try_execute(void (*func)()) { - signal(SIGILL, sigill_handler); - if (setjmp(jmpbuf) == 0) { - func(); - signal(SIGILL, SIG_DFL); - return true; - } else { - signal(SIGILL, SIG_DFL); - return false; - } -} - #ifdef __x86_64__ -std::vector run_avx2_computation() { +bool run_avx2_computation() { +#if defined(__AVX2__) alignas(32) int result[8]; alignas(32) int input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; alignas(32) int input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; @@ -43,10 +25,14 @@ std::vector run_avx2_computation() { __m256i vec_result = _mm256_add_epi32(vec1, vec2); _mm256_store_si256(reinterpret_cast<__m256i*>(result), vec_result); - return {result, result + 8}; + return true; +#else + return false; +#endif // __AVX2__ } -std::vector run_avx512f_computation() { +bool run_avx512f_computation() { +#ifdef __AVX512F__ alignas(64) long long result[8]; alignas(64) long long input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; alignas(64) long long input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; @@ -56,11 +42,15 @@ std::vector run_avx512f_computation() { __m512i vec_result = _mm512_add_epi64(vec1, vec2); _mm512_store_si512(reinterpret_cast<__m512i*>(result), vec_result); - return {result, result + 8}; + return true; +#else + return false; +#endif // __AVX512F__ } -std::vector run_avx512cd_computation() { - run_avx512f_computation(); +bool run_avx512cd_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512CD__ __m512i indices = _mm512_set_epi32( 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); @@ -68,38 +58,47 @@ std::vector run_avx512cd_computation() { alignas(64) int mask_array[16]; _mm512_store_epi32(mask_array, conflict_mask); - - return std::vector(); + return true; +#else + return false; +#endif // __AVX512CD__ } -std::vector run_avx512vl_computation() { - run_avx512f_computation(); +bool run_avx512vl_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512VL__ __m256i vec1 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); __m256i vec2 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); __m256i result = _mm256_add_epi32(vec1, vec2); alignas(32) int result_array[8]; _mm256_store_si256(reinterpret_cast<__m256i*>(result_array), result); - - return std::vector(result_array, result_array + 8); + return true; +#else + return false; +#endif // __AVX512VL__ } -std::vector run_avx512dq_computation() { - run_avx512f_computation(); +bool run_avx512dq_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512DQ__ __m512i vec1 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); __m512i vec2 = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); __m512i result = _mm512_add_epi64(vec1, vec2); alignas(64) long long result_array[8]; _mm512_store_si512(result_array, result); - - return std::vector(result_array, result_array + 8); + return true; +#else + return false; +#endif // __AVX512DQ__ } -std::vector run_avx512bw_computation() { - run_avx512f_computation(); +bool run_avx512bw_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512BW__ std::vector input1(64, 0); __m512i vec1 = _mm512_loadu_si512(reinterpret_cast(input1.data())); @@ -111,22 +110,13 @@ std::vector run_avx512bw_computation() { alignas(64) int8_t result_array[64]; _mm512_storeu_si512(reinterpret_cast<__m512i*>(result_array), result); - return std::vector(result_array, result_array + 64); + return true; +#else + return false; +#endif // __AVX512BW__ } #endif // 
__x86_64__ -std::pair> try_execute(std::vector (*func)()) { - signal(SIGILL, sigill_handler); - if (setjmp(jmpbuf) == 0) { - auto result = func(); - signal(SIGILL, SIG_DFL); - return std::make_pair(true, result); - } else { - signal(SIGILL, SIG_DFL); - return std::make_pair(false, std::vector()); - } -} - TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { faiss::SIMDLevel detected_level = faiss::SIMDConfig::auto_detect_simd_level(); @@ -140,10 +130,12 @@ TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { detected_level == faiss::SIMDLevel::AVX2 || detected_level == faiss::SIMDLevel::AVX512); #elif defined(__aarch64__) && defined(__ARM_NEON) - EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); + // Uncomment following line when dynamic dispatch is enabled for ARM_NEON + // EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); #else EXPECT_EQ(detected_level, faiss::SIMDLevel::NONE); #endif + EXPECT_TRUE(detected_level != faiss::SIMDLevel::COUNT); } #ifdef __x86_64__ @@ -151,10 +143,8 @@ TEST(SIMDConfig, successful_avx2_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)) { - auto actual_result = try_execute(run_avx2_computation); - EXPECT_TRUE(actual_result.first); - auto expected_result_vector = std::vector(8, 9); - EXPECT_EQ(actual_result.second, expected_result_vector); + auto actual_result = run_avx2_computation(); + EXPECT_TRUE(actual_result); } } @@ -171,10 +161,8 @@ TEST(SIMDConfig, successful_avx512f_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual_result = try_execute(run_avx512f_computation); - EXPECT_TRUE(actual_result.first); - auto expected_result_vector = std::vector(8, 9); - EXPECT_EQ(actual_result.second, expected_result_vector); + auto actual_result = run_avx512f_computation(); + EXPECT_TRUE(actual_result); } } @@ -182,8 +170,8 @@ TEST(SIMDConfig, successful_avx512cd_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual = try_execute(run_avx512cd_computation); - EXPECT_TRUE(actual.first); + auto actual = run_avx512cd_computation(); + EXPECT_TRUE(actual); } } @@ -191,9 +179,8 @@ TEST(SIMDConfig, successful_avx512vl_execution_on_x86arch) { faiss::SIMDConfig simd_config(nullptr); if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { - auto actual = try_execute(run_avx512vl_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(8, 7)); + auto actual = run_avx512vl_computation(); + EXPECT_TRUE(actual); } } @@ -203,9 +190,8 @@ TEST(SIMDConfig, successful_avx512dq_execution_on_x86arch) { if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { EXPECT_TRUE( simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); - auto actual = try_execute(run_avx512dq_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(8, 7)); + auto actual = run_avx512dq_computation(); + EXPECT_TRUE(actual); } } @@ -215,21 +201,22 @@ TEST(SIMDConfig, successful_avx512bw_execution_on_x86arch) { if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { EXPECT_TRUE( simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); - auto actual = try_execute(run_avx512bw_computation); - EXPECT_TRUE(actual.first); - EXPECT_EQ(actual.second, std::vector(64, 7)); + auto actual = run_avx512bw_computation(); + 
EXPECT_TRUE(actual); + // EXPECT_TRUE(actual.first); + // EXPECT_EQ(actual.second, std::vector(64, 7)); } } #endif // __x86_64__ TEST(SIMDConfig, override_simd_level) { - const char* faiss_env_var_neon = "ARM_NEON"; - faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); - EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); - EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); - EXPECT_TRUE(simd_neon_config.is_simd_level_available( - faiss::SIMDLevel::ARM_NEON)); + // EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); const char* faiss_env_var_avx512 = "AVX512"; faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); @@ -240,12 +227,12 @@ TEST(SIMDConfig, override_simd_level) { } TEST(SIMDConfig, simd_config_get_level_name) { - const char* faiss_env_var_neon = "ARM_NEON"; - faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); - EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); - EXPECT_TRUE(simd_neon_config.is_simd_level_available( - faiss::SIMDLevel::ARM_NEON)); - EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); + // EXPECT_EQ(faiss_env_var_neon, simd_neon_config.get_level_name()); const char* faiss_env_var_avx512 = "AVX512"; faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); @@ -259,7 +246,8 @@ TEST(SIMDLevel, get_level_name_from_enum) { EXPECT_EQ("NONE", to_string(faiss::SIMDLevel::NONE).value_or("")); EXPECT_EQ("AVX2", to_string(faiss::SIMDLevel::AVX2).value_or("")); EXPECT_EQ("AVX512", to_string(faiss::SIMDLevel::AVX512).value_or("")); - EXPECT_EQ("ARM_NEON", to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); + // EXPECT_EQ("ARM_NEON", + // to_string(faiss::SIMDLevel::ARM_NEON).value_or("")); int actual_num_simd_levels = static_cast(faiss::SIMDLevel::COUNT); EXPECT_EQ(4, actual_num_simd_levels); @@ -275,6 +263,6 @@ TEST(SIMDLevel, to_simd_level_from_string) { EXPECT_EQ(faiss::SIMDLevel::NONE, faiss::to_simd_level("NONE")); EXPECT_EQ(faiss::SIMDLevel::AVX2, faiss::to_simd_level("AVX2")); EXPECT_EQ(faiss::SIMDLevel::AVX512, faiss::to_simd_level("AVX512")); - EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); + // EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON")); EXPECT_FALSE(faiss::to_simd_level("INVALID").has_value()); } From 18ffaa780eb545d6512c504fcb565f996f0a604a Mon Sep 17 00:00:00 2001 From: matthijs Date: Thu, 28 Aug 2025 00:59:30 -0700 Subject: [PATCH 3/5] moved IndexIVFPQ and IndexPQ to dynamic dispatch (#4291) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4291 moved IndexIVFPQ and IndexPQ to dynamic dispatch. Since the code was already quite modular (thanks Alex!), this boils down to make independent cpp files for the different SIMD versions. 
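For orientation, the DISPATCH_SIMDLevel(...) calls introduced below can be read
as a switch on the SIMD level detected at runtime that forwards to the matching
template instantiation. The sketch that follows is only an illustration of that
idea, not the macro shipped in this series (the real macro is defined elsewhere
in the patch stack, is not shown here, and will in addition guard each case with
the COMPILE_SIMD_* build flags):

    // Hypothetical sketch of a runtime SIMD dispatch macro. Assumes the
    // dispatched function is a template over SIMDLevel and that
    // SIMDConfig::get_level() returns the detected level; levels without a
    // specialized build (e.g. ARM_NEON here) fall back to the generic code.
    #define DISPATCH_SIMDLevel(func, ...)                             \
        switch (faiss::SIMDConfig::get_level()) {                     \
            case faiss::SIMDLevel::AVX512:                            \
                return func<faiss::SIMDLevel::AVX512>(__VA_ARGS__);   \
            case faiss::SIMDLevel::AVX2:                              \
                return func<faiss::SIMDLevel::AVX2>(__VA_ARGS__);     \
            default:                                                  \
                return func<faiss::SIMDLevel::NONE>(__VA_ARGS__);     \
        }

With a macro of this shape, IndexIVFPQ::get_InvertedListScanner and
IndexPQ::get_FlatCodesDistanceComputer only select the SIMD level; the other
template dispatch stages visible in the diff below (metric, decoder width,
selector) are unchanged.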
Differential Revision: D72937709 --- faiss/IndexIVFPQ.cpp | 68 ++- faiss/IndexPQ.cpp | 32 +- .../impl/code_distance/code_distance-avx2.cpp | 490 ++++++++++++++++ faiss/impl/code_distance/code_distance-avx2.h | 534 ------------------ .../code_distance/code_distance-avx512.cpp | 203 +++++++ .../impl/code_distance/code_distance-avx512.h | 248 -------- .../code_distance/code_distance-generic.cpp | 20 + .../code_distance/code_distance-generic.h | 81 --- ...e_distance-sve.h => code_distance-sve.cpp} | 4 +- faiss/impl/code_distance/code_distance.h | 229 +++----- faiss/utils/{ => simd_impl}/simdlib_avx2.h | 0 faiss/utils/{ => simd_impl}/simdlib_avx512.h | 0 .../utils/{ => simd_impl}/simdlib_emulated.h | 0 faiss/utils/{ => simd_impl}/simdlib_neon.h | 0 faiss/utils/{ => simd_impl}/simdlib_ppc64.h | 0 faiss/utils/simdlib.h | 8 +- tests/test_code_distance.cpp | 73 ++- 17 files changed, 896 insertions(+), 1094 deletions(-) create mode 100644 faiss/impl/code_distance/code_distance-avx2.cpp delete mode 100644 faiss/impl/code_distance/code_distance-avx2.h create mode 100644 faiss/impl/code_distance/code_distance-avx512.cpp delete mode 100644 faiss/impl/code_distance/code_distance-avx512.h create mode 100644 faiss/impl/code_distance/code_distance-generic.cpp delete mode 100644 faiss/impl/code_distance/code_distance-generic.h rename faiss/impl/code_distance/{code_distance-sve.h => code_distance-sve.cpp} (99%) rename faiss/utils/{ => simd_impl}/simdlib_avx2.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_avx512.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_emulated.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_neon.h (100%) rename faiss/utils/{ => simd_impl}/simdlib_ppc64.h (100%) diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index 41e0192ff7..bcdbf913cf 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -817,8 +817,9 @@ struct RangeSearchResults { * The scanning functions call their favorite precompute_* * function to precompute the tables they need. 
*****************************************************/ -template +template struct IVFPQScannerT : QueryTables { + using PQDecoder = typename PQCodeDistance::PQDecoder; const uint8_t* list_codes; const IDType* list_ids; size_t list_size; @@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = 0; float distance_2 = 0; float distance_3 = 0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables { if (counter >= 1) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 2) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables { } if (counter >= 3) { float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables { float distance_1 = dis0; float distance_2 = dis0; float distance_3 = dis0; - distance_four_codes( + PQCodeDistance::distance_four_codes( pq.M, pq.nbits, sim_table, @@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables { n_hamming_pass++; float dis = dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( pq.M, pq.nbits, sim_table, @@ -1197,8 +1198,8 @@ struct IVFPQScannerT : QueryTables { * * use_sel: store or ignore the IDSelector */ -template -struct IVFPQScanner : IVFPQScannerT, +template +struct IVFPQScanner : IVFPQScannerT, InvertedListScanner { int precompute_mode; const IDSelector* sel; @@ -1208,7 +1209,7 @@ struct IVFPQScanner : IVFPQScannerT, bool store_pairs, int precompute_mode, const IDSelector* sel) - : IVFPQScannerT(ivfpq, nullptr), + : IVFPQScannerT(ivfpq, nullptr), precompute_mode(precompute_mode), sel(sel) { this->store_pairs = store_pairs; @@ -1228,7 +1229,7 @@ struct IVFPQScanner : IVFPQScannerT, float distance_to_code(const uint8_t* code) const override { assert(precompute_mode == 2); float dis = this->dis0 + - distance_single_code( + PQCodeDistance::distance_single_code( this->pq.M, this->pq.nbits, this->sim_table, code); return dis; } @@ -1292,7 +1293,9 @@ struct IVFPQScanner : IVFPQScannerT, } }; -template +/** follow 3 stages of template dispatching */ + +template InvertedListScanner* get_InvertedListScanner1( const IndexIVFPQ& index, bool store_pairs, @@ -1301,32 +1304,47 @@ InvertedListScanner* get_InvertedListScanner1( return new IVFPQScanner< METRIC_INNER_PRODUCT, CMin, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } else if (index.metric_type == METRIC_L2) { return new IVFPQScanner< METRIC_L2, CMax, - PQDecoder, + PQCodeDistance, use_sel>(index, store_pairs, 2, sel); } return nullptr; } -template +template InvertedListScanner* get_InvertedListScanner2( const IndexIVFPQ& index, bool store_pairs, const IDSelector* sel) { if (index.pq.nbits == 8) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); } else if (index.pq.nbits == 16) { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } 
else { + return get_InvertedListScanner1< + PQCodeDistance, + use_sel>(index, store_pairs, sel); + } +} + +template +InvertedListScanner* get_InvertedListScanner3( + const IndexIVFPQ& index, + bool store_pairs, + const IDSelector* sel) { + if (sel) { + return get_InvertedListScanner2(index, store_pairs, sel); } else { - return get_InvertedListScanner1( - index, store_pairs, sel); + return get_InvertedListScanner2(index, store_pairs, sel); } } @@ -1336,11 +1354,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner( bool store_pairs, const IDSelector* sel, const IVFSearchParameters*) const { - if (sel) { - return get_InvertedListScanner2(*this, store_pairs, sel); - } else { - return get_InvertedListScanner2(*this, store_pairs, sel); - } + DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel); return nullptr; } diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 4f7a2d0f62..2f1b220d03 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -72,7 +72,7 @@ void IndexPQ::train(idx_t n, const float* x) { namespace { -template +template struct PQDistanceComputer : FlatCodesDistanceComputer { size_t d; MetricType metric; @@ -85,7 +85,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { float distance_to_code(const uint8_t* code) final { ndis++; - float dis = distance_single_code( + float dis = PQCodeDistance::distance_single_code( pq.M, pq.nbits, precomputed_table.data(), code); return dis; } @@ -94,8 +94,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { FAISS_THROW_IF_NOT(sdc); const float* sdci = sdc; float accu = 0; - PQDecoder codei(codes + i * code_size, pq.nbits); - PQDecoder codej(codes + j * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codei( + codes + i * code_size, pq.nbits); + typename PQCodeDistance::PQDecoder codej( + codes + j * code_size, pq.nbits); for (int l = 0; l < pq.M; l++) { accu += sdci[codei.decode() + (codej.decode() << codei.nbits)]; @@ -131,16 +133,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer { } }; +template +FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1( + const IndexPQ& index) { + int nbits = index.pq.nbits; + if (nbits == 8) { + return new PQDistanceComputer>(index); + } else if (nbits == 16) { + return new PQDistanceComputer>(index); + } else { + return new PQDistanceComputer>( + index); + } +} + } // namespace FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const { - if (pq.nbits == 8) { - return new PQDistanceComputer(*this); - } else if (pq.nbits == 16) { - return new PQDistanceComputer(*this); - } else { - return new PQDistanceComputer(*this); - } + DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this); } /***************************************** diff --git a/faiss/impl/code_distance/code_distance-avx2.cpp b/faiss/impl/code_distance/code_distance-avx2.cpp new file mode 100644 index 0000000000..e1e12daca2 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx2.cpp @@ -0,0 +1,490 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifdef COMPILE_SIMD_AVX2 + +#include + +#include +#include + +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 +#if defined(__GNUC__) && __GNUC__ < 9 +#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) +#endif + +namespace { + +inline float horizontal_sum(const __m128 v) { + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(v, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); +} + +// Computes a horizontal sum over an __m256 register +inline float horizontal_sum(const __m256 v) { + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + return horizontal_sum(v0); +} + +// processes a single code for M=4, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m128i vksub = _mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSum; + + // load 4 uint8 values + const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes a single code for M=8, ksub=256, nbits=8 +float inline distance_single_code_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + const uint8_t* code) { + float result = 0; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum; + + // load 8 uint8 values + const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSum = collected; + } + + // horizontal sum for partialSum + result = horizontal_sum(partialSum); + return result; +} + +// processes four codes for M=4, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m4( + // precomputed distances, layout (4, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m128i vksub = 
_mm_set1_epi32(ksub); + __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); + offsets_0 = _mm_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m128 partialSums[N]; + + // load 4 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); + mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); + mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); + mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); + + // gather 4 values, similar to 4 operations of tab[idx] + __m128 collected = + _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +// processes four codes for M=8, ksub=256, nbits=8 +inline void distance_four_codes_avx2_pqdecoder8_m8( + // precomputed distances, layout (8, 256) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr intptr_t N = 4; + + const float* tab = sim_table; + constexpr size_t ksub = 1 << 8; + + // process 8 values + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + + // load 8 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); + mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); + mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); + mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); + + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = + _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = collected; + } + + // horizontal sum for partialSum + result0 = horizontal_sum(partialSums[0]); + result1 = horizontal_sum(partialSums[1]); + result2 = horizontal_sum(partialSums[2]); + result3 = horizontal_sum(partialSums[3]); +} + +} // namespace + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + if (M == 4) { + return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); + } + if (M == 8) { + return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); + } + + float result = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + + const __m256i vksub = 
_mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSum = _mm256_setzero_ps(); + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + const __m128i mm1 = + _mm_loadu_si128((const __m128i_u*)(code + m)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + } + + // horizontal sum for partialSum + result += horizontal_sum(partialSum); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder(code + m, nbits); + + for (; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + } + + return result; + } + + // Combines 4 operations of distance_single_code() + void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + if (M == 4) { + distance_four_codes_avx2_pqdecoder8_m4( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + if (M == 8) { + distance_four_codes_avx2_pqdecoder8_m8( + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m256i vksub = _mm256_set1_epi32(ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m256 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm256_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + // convert uint8 values (low part of __m128i) 
to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + tab += ksub * 8; + + // process next 8 codes + for (intptr_t j = 0; j < N; j++) { + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); + + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm256_add_ps(partialSums[j], collected); + } + + tab += ksub * 8; + } + + // horizontal sum for partialSum + result0 += horizontal_sum(partialSums[0]); + result1 += horizontal_sum(partialSums[1]); + result2 += horizontal_sum(partialSums[2]); + result3 += horizontal_sum(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX2 diff --git a/faiss/impl/code_distance/code_distance-avx2.h b/faiss/impl/code_distance/code_distance-avx2.h deleted file mode 100644 index 53380b6e46..0000000000 --- a/faiss/impl/code_distance/code_distance-avx2.h +++ /dev/null @@ -1,534 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#pragma once - -#ifdef __AVX2__ - -#include - -#include - -#include -#include - -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 -#if defined(__GNUC__) && __GNUC__ < 9 -#define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) -#endif - -namespace { - -inline float horizontal_sum(const __m128 v) { - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(v, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); -} - -// Computes a horizontal sum over an __m256 register -inline float horizontal_sum(const __m256 v) { - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - return horizontal_sum(v0); -} - -// processes a single code for M=4, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSum; - - // load 4 uint8 values - const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes a single code for M=8, ksub=256, nbits=8 -float inline distance_single_code_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - const uint8_t* code) { - float result = 0; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum; - - // load 8 uint8 values - const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSum = collected; - } - - // horizontal sum for partialSum - result = horizontal_sum(partialSum); - return result; -} - -// processes four codes for M=4, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m4( - // precomputed distances, layout (4, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - 
const __m128i vksub = _mm_set1_epi32(ksub); - __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); - offsets_0 = _mm_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m128 partialSums[N]; - - // load 4 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); - mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); - mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); - mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); - - // gather 4 values, similar to 4 operations of tab[idx] - __m128 collected = - _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -// processes four codes for M=8, ksub=256, nbits=8 -inline void distance_four_codes_avx2_pqdecoder8_m8( - // precomputed distances, layout (8, 256) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - constexpr intptr_t N = 4; - - const float* tab = sim_table; - constexpr size_t ksub = 1 << 8; - - // process 8 values - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - - // load 8 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); - mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); - mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); - mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); - - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = - _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = collected; - } - - // horizontal sum for partialSum - result0 = horizontal_sum(partialSums[0]); - result1 = horizontal_sum(partialSums[1]); - result2 = horizontal_sum(partialSums[2]); - result3 = horizontal_sum(partialSums[3]); -} - -} // namespace - -namespace faiss { - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* 
sim_table, - const uint8_t* code) { - if (M == 4) { - return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); - } - if (M == 8) { - return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); - } - - float result = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSum = _mm256_setzero_ps(); - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m)); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - - // move high 8 uint8 to low ones - const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); - { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - tab += ksub * 8; - - // collect partial sums - partialSum = _mm256_add_ps(partialSum, collected); - } - } - - // horizontal sum for partialSum - result += horizontal_sum(partialSum); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder(code + m, nbits); - - for (; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - } - - return result; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx2( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - if (M == 4) { - distance_four_codes_avx2_pqdecoder8_m4( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - if (M == 8) { - 
distance_four_codes_avx2_pqdecoder8_m8( - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); - return; - } - - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m256i vksub = _mm256_set1_epi32(ksub); - __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m256 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm256_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - tab += ksub * 8; - - // process next 8 codes - for (intptr_t j = 0; j < N; j++) { - // move high 8 uint8 to low ones - const __m128i mm2 = - _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); - - // convert uint8 values (low part of __m128i) to int32 - // values - const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); - - // add offsets - const __m256i indices_to_read_from = - _mm256_add_epi32(idx1, offsets_0); - - // gather 8 values, similar to 8 operations of tab[idx] - __m256 collected = _mm256_i32gather_ps( - tab, indices_to_read_from, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm256_add_ps(partialSums[j], collected); - } - - tab += ksub * 8; - } - - // horizontal sum for partialSum - result0 += horizontal_sum(partialSums[0]); - result1 += horizontal_sum(partialSums[1]); - result2 += horizontal_sum(partialSums[2]); - result3 += horizontal_sum(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-avx512.cpp b/faiss/impl/code_distance/code_distance-avx512.cpp new file mode 100644 index 0000000000..aa16b1c4b8 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx512.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifdef COMPILE_SIMD_AVX512 + +#include + +#include + +#include +#include + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +namespace faiss { + +template <> +struct PQCodeDistance { + float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; + } + + void distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + 
const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } + } +}; + +// explicit template instanciations +// template struct PQCodeDistance; + +// these two will automatically use the generic implementation +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX512F diff --git a/faiss/impl/code_distance/code_distance-avx512.h b/faiss/impl/code_distance/code_distance-avx512.h deleted file mode 100644 index d05c41c19c..0000000000 --- a/faiss/impl/code_distance/code_distance-avx512.h +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef __AVX512F__ - -#include - -#include - -#include -#include - -namespace faiss { - -// According to experiments, the AVX-512 version may be SLOWER than -// the AVX2 version, which is somewhat unexpected. -// This version is not used for now, but it may be used later. -// -// TODO: test for AMD CPUs. 
- -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code) { - // default implementation - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -typename std::enable_if::value, float>:: - type inline distance_single_code_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - const uint8_t* code0) { - float result0 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t pqM16 = M / 16; - - constexpr intptr_t N = 1; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - tab += ksub; - } - } - - return result0; -} - -template -typename std::enable_if::value, void>:: - type - distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -// Combines 4 operations of distance_single_code() -template -typename std::enable_if::value, void>::type -distance_four_codes_avx512( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - constexpr size_t ksub = 1 << 8; - - size_t m = 0; - const size_t 
pqM16 = M / 16; - - constexpr intptr_t N = 4; - - const float* tab = sim_table; - - if (pqM16 > 0) { - // process 16 values per loop - const __m512i vksub = _mm512_set1_epi32(ksub); - __m512i offsets_0 = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); - - // accumulators of partial sums - __m512 partialSums[N]; - for (intptr_t j = 0; j < N; j++) { - partialSums[j] = _mm512_setzero_ps(); - } - - // loop - for (m = 0; m < pqM16 * 16; m += 16) { - // load 16 uint8 values - __m128i mm1[N]; - mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); - mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); - mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); - mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - - // process first 8 codes - for (intptr_t j = 0; j < N; j++) { - const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - - // add offsets - const __m512i indices_to_read_from = - _mm512_add_epi32(idx1, offsets_0); - - // gather 16 values, similar to 16 operations of tab[idx] - __m512 collected = _mm512_i32gather_ps( - indices_to_read_from, tab, sizeof(float)); - - // collect partial sums - partialSums[j] = _mm512_add_ps(partialSums[j], collected); - } - tab += ksub * 16; - } - - // horizontal sum for partialSum - result0 += _mm512_reduce_add_ps(partialSums[0]); - result1 += _mm512_reduce_add_ps(partialSums[1]); - result2 += _mm512_reduce_add_ps(partialSums[2]); - result3 += _mm512_reduce_add_ps(partialSums[3]); - } - - // - if (m < M) { - // process leftovers - PQDecoder8 decoder0(code0 + m, nbits); - PQDecoder8 decoder1(code1 + m, nbits); - PQDecoder8 decoder2(code2 + m, nbits); - PQDecoder8 decoder3(code3 + m, nbits); - for (; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } - } -} - -} // namespace faiss - -#endif diff --git a/faiss/impl/code_distance/code_distance-generic.cpp b/faiss/impl/code_distance/code_distance-generic.cpp new file mode 100644 index 0000000000..ac9561ed93 --- /dev/null +++ b/faiss/impl/code_distance/code_distance-generic.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace faiss { + +// explicit template instanciations +template struct PQCodeDistance; +template struct PQCodeDistance; +template struct PQCodeDistance; + +} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-generic.h b/faiss/impl/code_distance/code_distance-generic.h deleted file mode 100644 index c02551c415..0000000000 --- a/faiss/impl/code_distance/code_distance-generic.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace faiss { - -/// Returns the distance to a single code. 
-template -inline float distance_single_code_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - PQDecoderT decoder(code, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - float result = 0; - - for (size_t m = 0; m < M; m++) { - result += tab[decoder.decode()]; - tab += ksub; - } - - return result; -} - -/// Combines 4 operations of distance_single_code() -/// General-purpose version. -template -inline void distance_four_codes_generic( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - PQDecoderT decoder0(code0, nbits); - PQDecoderT decoder1(code1, nbits); - PQDecoderT decoder2(code2, nbits); - PQDecoderT decoder3(code3, nbits); - const size_t ksub = 1 << nbits; - - const float* tab = sim_table; - result0 = 0; - result1 = 0; - result2 = 0; - result3 = 0; - - for (size_t m = 0; m < M; m++) { - result0 += tab[decoder0.decode()]; - result1 += tab[decoder1.decode()]; - result2 += tab[decoder2.decode()]; - result3 += tab[decoder3.decode()]; - tab += ksub; - } -} - -} // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-sve.h b/faiss/impl/code_distance/code_distance-sve.cpp similarity index 99% rename from faiss/impl/code_distance/code_distance-sve.h rename to faiss/impl/code_distance/code_distance-sve.cpp index 82f7746be6..9a941798ff 100644 --- a/faiss/impl/code_distance/code_distance-sve.h +++ b/faiss/impl/code_distance/code_distance-sve.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -#pragma once - #ifdef __ARM_FEATURE_SVE #include @@ -15,7 +13,7 @@ #include #include -#include +#include namespace faiss { diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h index 8f29abda97..585890cb40 100644 --- a/faiss/impl/code_distance/code_distance.h +++ b/faiss/impl/code_distance/code_distance.h @@ -9,6 +9,10 @@ #include +#include + +#include + // This directory contains functions to compute a distance // from a given PQ code to a query vector, given that the // distances to a query vector for pq.M codebooks are precomputed. @@ -24,163 +28,76 @@ // why the names of the functions for custom implementations // have this _generic or _avx2 suffix. 
-#ifdef __AVX2__ - -#include - namespace faiss { -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_avx2(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_avx2( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} +// definiton and default implementation +template +struct PQCodeDistance { + using PQDecoder = PQDecoderT; + + /// Returns the distance to a single code. + static float distance_single_code( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // the code + const uint8_t* code) { + PQDecoderT decoder(code, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + float result = 0; + + for (size_t m = 0; m < M; m++) { + result += tab[decoder.decode()]; + tab += ksub; + } + + return result; + } + + /// Combines 4 operations of distance_single_code() + /// General-purpose version. + static void distance_four_codes( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + PQDecoderT decoder0(code0, nbits); + PQDecoderT decoder1(code1, nbits); + PQDecoderT decoder2(code2, nbits); + PQDecoderT decoder3(code3, nbits); + const size_t ksub = 1 << nbits; + + const float* tab = sim_table; + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + + for (size_t m = 0; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +}; } // namespace faiss - -#elif defined(__ARM_FEATURE_SVE) - -#include - -namespace faiss { - -template -inline float distance_single_code( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_sve(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // the product quantizer - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, 
- float& result2, - float& result3) { - distance_four_codes_sve( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#else - -#include - -namespace faiss { - -template -inline float distance_single_code( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // the code - const uint8_t* code) { - return distance_single_code_generic(M, nbits, sim_table, code); -} - -template -inline void distance_four_codes( - // number of subquantizers - const size_t M, - // number of bits per quantization index - const size_t nbits, - // precomputed distances, layout (M, ksub) - const float* sim_table, - // codes - const uint8_t* __restrict code0, - const uint8_t* __restrict code1, - const uint8_t* __restrict code2, - const uint8_t* __restrict code3, - // computed distances - float& result0, - float& result1, - float& result2, - float& result3) { - distance_four_codes_generic( - M, - nbits, - sim_table, - code0, - code1, - code2, - code3, - result0, - result1, - result2, - result3); -} - -} // namespace faiss - -#endif diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simd_impl/simdlib_avx2.h similarity index 100% rename from faiss/utils/simdlib_avx2.h rename to faiss/utils/simd_impl/simdlib_avx2.h diff --git a/faiss/utils/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h similarity index 100% rename from faiss/utils/simdlib_avx512.h rename to faiss/utils/simd_impl/simdlib_avx512.h diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simd_impl/simdlib_emulated.h similarity index 100% rename from faiss/utils/simdlib_emulated.h rename to faiss/utils/simd_impl/simdlib_emulated.h diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simd_impl/simdlib_neon.h similarity index 100% rename from faiss/utils/simdlib_neon.h rename to faiss/utils/simd_impl/simdlib_neon.h diff --git a/faiss/utils/simdlib_ppc64.h b/faiss/utils/simd_impl/simdlib_ppc64.h similarity index 100% rename from faiss/utils/simdlib_ppc64.h rename to faiss/utils/simd_impl/simdlib_ppc64.h diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h index eadfb78ae3..98c38f7a0d 100644 --- a/faiss/utils/simdlib.h +++ b/faiss/utils/simdlib.h @@ -21,20 +21,20 @@ #elif defined(__AVX2__) -#include +#include #elif defined(__aarch64__) -#include +#include #elif defined(__PPC64__) -#include +#include #else // emulated = all operations are implemented as scalars -#include +#include // FIXME: make a SSE version // is this ever going to happen? 
We will probably rather implement AVX512 diff --git a/tests/test_code_distance.cpp b/tests/test_code_distance.cpp index f1a3939388..e4b61baf0f 100644 --- a/tests/test_code_distance.cpp +++ b/tests/test_code_distance.cpp @@ -22,6 +22,7 @@ #include #include #include +#include size_t nMismatches( const std::vector& ref, @@ -80,8 +81,10 @@ void test( for (size_t k = 0; k < 10; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i++) { - resultsRef[i] = - faiss::distance_single_code_generic( + resultsRef[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), codes.data() + subq * i); } } @@ -94,8 +97,10 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i++) { - resultsNewGeneric1x[i] = - faiss::distance_single_code_generic( + resultsNewGeneric1x[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), @@ -117,18 +122,21 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i += 4) { - faiss::distance_four_codes_generic( - subq, - 8, - lookup.data(), - codes.data() + subq * (i + 0), - codes.data() + subq * (i + 1), - codes.data() + subq * (i + 2), - codes.data() + subq * (i + 3), - resultsNewGeneric4x[i + 0], - resultsNewGeneric4x[i + 1], - resultsNewGeneric4x[i + 2], - resultsNewGeneric4x[i + 3]); + faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_four_codes( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewGeneric4x[i + 0], + resultsNewGeneric4x[i + 1], + resultsNewGeneric4x[i + 2], + resultsNewGeneric4x[i + 3]); } } @@ -147,8 +155,10 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i++) { - resultsNewCustom1x[i] = - faiss::distance_single_code( + resultsNewCustom1x[i] = faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_single_code( subq, 8, lookup.data(), @@ -170,18 +180,21 @@ void test( for (size_t k = 0; k < 1000; k++) { #pragma omp parallel for schedule(guided) for (size_t i = 0; i < n; i += 4) { - faiss::distance_four_codes( - subq, - 8, - lookup.data(), - codes.data() + subq * (i + 0), - codes.data() + subq * (i + 1), - codes.data() + subq * (i + 2), - codes.data() + subq * (i + 3), - resultsNewCustom4x[i + 0], - resultsNewCustom4x[i + 1], - resultsNewCustom4x[i + 2], - resultsNewCustom4x[i + 3]); + faiss::PQCodeDistance< + faiss::PQDecoder8, + faiss::SIMDLevel::NONE>:: + distance_four_codes( + subq, + 8, + lookup.data(), + codes.data() + subq * (i + 0), + codes.data() + subq * (i + 1), + codes.data() + subq * (i + 2), + codes.data() + subq * (i + 3), + resultsNewCustom4x[i + 0], + resultsNewCustom4x[i + 1], + resultsNewCustom4x[i + 2], + resultsNewCustom4x[i + 3]); } } From 97352161d15df9168efdb11d04b6c52461dc47c3 Mon Sep 17 00:00:00 2001 From: matthijs Date: Thu, 28 Aug 2025 00:59:30 -0700 Subject: [PATCH 4/5] Use simdlib abstraction in ScalarQuantizer implementation, split off training code, split quantizer code into headers, Make headers more independent Summary: Move the interface of SIMD functions to use the simdXfloat32 API to mutualize code. Begin splitting the ScalarQuantizer.cpp Continue splitting. Purely in header files for now. 
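// Sketch of the "one code path, several ISAs" idea behind moving the SIMD
// interfaces to the simdXfloat32 API mentioned in the summary above. The
// wrapper below ("vec8") is a hypothetical stand-in, not the real faiss
// simd8float32 type; it only illustrates how a reconstruction kernel can be
// written once against a width-abstracted register and compiled per target.
#include <cstddef>

struct vec8 {            // 8 x float32, backed by whatever the target offers
    float v[8];          // scalar fallback shown here; an AVX2 build would wrap
                         // a __m256 and a NEON build a float32x4x2_t
    static vec8 load(const float* p) {
        vec8 r;
        for (int i = 0; i < 8; i++) r.v[i] = p[i];
        return r;
    }
    vec8 fmadd(const vec8& a, const vec8& b) const { // this * a + b
        vec8 r;
        for (int i = 0; i < 8; i++) r.v[i] = v[i] * a.v[i] + b.v[i];
        return r;
    }
};

// One reconstruction kernel, independent of the underlying register type.
inline vec8 reconstruct_8(const float* decoded01, const float* vmin, const float* vdiff) {
    return vec8::load(decoded01).fmadd(vec8::load(vdiff), vec8::load(vmin));
}

int main() {
    float decoded[8] = {0, .1f, .2f, .3f, .4f, .5f, .6f, .7f};
    float vmin[8], vdiff[8];
    for (int i = 0; i < 8; i++) { vmin[i] = -1.f; vdiff[i] = 2.f; }
    vec8 r = reconstruct_8(decoded, vmin, vdiff); // maps [0,1] codes into [-1,1]
    return r.v[0] < 0 ? 0 : 1;
}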
Differential Revision: D72945865 --- faiss/impl/ScalarQuantizer.cpp | 1965 +---------------- faiss/impl/ScalarQuantizer.h | 2 - faiss/impl/scalar_quantizer/codecs.h | 305 +++ .../scalar_quantizer/distance_computers.h | 381 ++++ faiss/impl/scalar_quantizer/quantizers.h | 586 +++++ faiss/impl/scalar_quantizer/similarities.h | 345 +++ faiss/impl/scalar_quantizer/training.cpp | 188 ++ faiss/impl/scalar_quantizer/training.h | 40 + faiss/utils/simd_impl/simdlib_avx512.h | 2 +- faiss/utils/simdlib.h | 4 +- 10 files changed, 1962 insertions(+), 1856 deletions(-) create mode 100644 faiss/impl/scalar_quantizer/codecs.h create mode 100644 faiss/impl/scalar_quantizer/distance_computers.h create mode 100644 faiss/impl/scalar_quantizer/quantizers.h create mode 100644 faiss/impl/scalar_quantizer/similarities.h create mode 100644 faiss/impl/scalar_quantizer/training.cpp create mode 100644 faiss/impl/scalar_quantizer/training.h diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index af90b7e130..ada60c116d 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -7,1881 +7,142 @@ // -*- c++ -*- -#include - -#include -#include - -#include - -#ifdef __SSE__ -#include -#endif - -#include -#include -#include -#include -#include -#include -#include - -namespace faiss { - -/******************************************************************* - * ScalarQuantizer implementation - * - * The main source of complexity is to support combinations of 4 - * variants without incurring runtime tests or virtual function calls: - * - * - 4 / 8 bits per code component - * - uniform / non-uniform - * - IP / L2 distance search - * - scalar / AVX distance computation - * - * The appropriate Quantizer object is returned via select_quantizer - * that hides the template mess. - ********************************************************************/ - -#if defined(__AVX512F__) && defined(__F16C__) -#define USE_AVX512_F16C -#elif defined(__AVX2__) -#ifdef __F16C__ -#define USE_F16C -#else -#warning \ - "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" -#endif -#endif - -#if defined(__aarch64__) -#if defined(__GNUC__) && __GNUC__ < 8 -#warning \ - "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" -#else -#define USE_NEON -#endif -#endif - -namespace { - -typedef ScalarQuantizer::QuantizerType QuantizerType; -typedef ScalarQuantizer::RangeStat RangeStat; -using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; - -/******************************************************************* - * Codec: converts between values in [0, 1] and an index in a code - * array. The "i" parameter is the vector component index (not byte - * index). 
- */ - -struct Codec8bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - code[i] = (int)(255 * x); - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - return (code[i] + 0.5f) / 255.0f; - } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); - const __m512i i32 = _mm512_cvtepu8_epi32(c16); - const __m512 f16 = _mm512_cvtepi32_ps(i32); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); - return _mm512_fmadd_ps(f16, one_255, half_one_255); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - const uint64_t c8 = *(uint64_t*)(code + i); - - const __m128i i8 = _mm_set1_epi64x(c8); - const __m256i i32 = _mm256_cvtepu8_epi32(i8); - const __m256 f8 = _mm256_cvtepi32_ps(i32); - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -struct Codec4bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2); - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; - } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - uint64_t c8 = *(uint64_t*)(code + (i >> 1)); - uint64_t mask = 0x0f0f0f0f0f0f0f0f; - uint64_t c8ev = c8 & mask; - uint64_t c8od = (c8 >> 4) & mask; - - __m128i c16 = - _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); - __m256i c8lo = _mm256_cvtepu8_epi32(c16); - __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); - __m512i i16 = _mm512_castsi256_si512(c8lo); - i16 = _mm512_inserti32x8(i16, c8hi, 1); - __m512 f16 = _mm512_cvtepi32_ps(i16); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); - return _mm512_fmadd_ps(f16, one_255, half_one_255); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - uint32_t c4 = *(uint32_t*)(code + (i >> 1)); - uint32_t mask = 0x0f0f0f0f; - uint32_t c4ev = c4 & mask; - uint32_t c4od = (c4 >> 4) & mask; - - // the 8 lower bytes of c8 contain the values - __m128i c8 = - _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); - __m128i c4lo = _mm_cvtepu8_epi32(c8); - __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); - __m256i i8 = _mm256_castsi128_si256(c4lo); - i8 = _mm256_insertf128_si256(i8, c4hi, 1); - __m256 f8 = _mm256_cvtepi32_ps(i8); - __m256 half = _mm256_set1_ps(0.5f); - f8 = _mm256_add_ps(f8, half); - __m256 one_255 = _mm256_set1_ps(1.f / 15.f); - return _mm256_mul_ps(f8, one_255); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; 
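// Standalone sketch of the scalar codec round trip above: a value in [0, 1]
// is mapped to an integer code and decoded back to the bucket midpoint, so
// the reconstruction error is at most half a bucket (0.5/255 for 8 bits,
// 0.5/15 for 4 bits). Helper names are illustrative only.
#include <cassert>
#include <cmath>
#include <cstdint>

static uint8_t encode8(float x)   { return (uint8_t)(255 * x); }
static float   decode8(uint8_t c) { return (c + 0.5f) / 255.0f; }

int main() {
    for (float x = 0.0f; x <= 1.0f; x += 0.01f) {
        float rec = decode8(encode8(x));
        assert(std::fabs(rec - x) <= 0.5f / 255.0f + 1e-6f);
    }
    return 0;
}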
- for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -struct Codec6bit { - static FAISS_ALWAYS_INLINE void encode_component( - float x, - uint8_t* code, - int i) { - int bits = (int)(x * 63.0); - code += (i >> 2) * 3; - switch (i & 3) { - case 0: - code[0] |= bits; - break; - case 1: - code[0] |= bits << 6; - code[1] |= bits >> 2; - break; - case 2: - code[1] |= bits << 4; - code[2] |= bits >> 4; - break; - case 3: - code[2] |= bits << 2; - break; - } - } - - static FAISS_ALWAYS_INLINE float decode_component( - const uint8_t* code, - int i) { - uint8_t bits; - code += (i >> 2) * 3; - switch (i & 3) { - case 0: - bits = code[0] & 0x3f; - break; - case 1: - bits = code[0] >> 6; - bits |= (code[1] & 0xf) << 2; - break; - case 2: - bits = code[1] >> 4; - bits |= (code[2] & 3) << 4; - break; - case 3: - bits = code[2] >> 2; - break; - } - return (bits + 0.5f) / 63.0f; - } - -#if defined(__AVX512F__) - - static FAISS_ALWAYS_INLINE __m512 - decode_16_components(const uint8_t* code, int i) { - // pure AVX512 implementation (not necessarily the fastest). - // see: - // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h - - // clang-format off - - // 16 components, 16x6 bit=12 bytes - const __m128i bit_6v = - _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); - const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); - - // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F - // 00 01 02 03 - const __m256i shuffle_mask = _mm256_setr_epi16( - 0xFF00, 0x0100, 0x0201, 0xFF02, - 0xFF03, 0x0403, 0x0504, 0xFF05, - 0xFF06, 0x0706, 0x0807, 0xFF08, - 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); - const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); - - // 0: xxxxxxxx xx543210 - // 1: xxxx5432 10xxxxxx - // 2: xxxxxx54 3210xxxx - // 3: xxxxxxxx 543210xx - const __m256i shift_right_v = _mm256_setr_epi16( - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U); - __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); - - // remove unneeded bits - shuffled_shifted = - _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); - - // scale - const __m512 f8 = - _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); - return _mm512_fmadd_ps(f8, one_255, half_one_255); - - // clang-format on - } - -#elif defined(__AVX2__) - - /* Load 6 bytes that represent 8 6-bit values, return them as a - * 8*32 bit vector register */ - static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { - const __m128i perm = _mm_set_epi8( - -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); - const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); - - // load 6 bytes - __m128i c1 = - _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); - - // put in 8 * 32 bits - __m128i c2 = _mm_shuffle_epi8(c1, perm); - __m256i c3 = _mm256_cvtepi16_epi32(c2); - - // shift and mask out useless bits - __m256i c4 = _mm256_srlv_epi32(c3, shifts); - __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); - return c5; - } - - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here - // // for the reference, maybe, it 
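// Standalone sketch of the 6-bit packing above: four 6-bit components occupy
// exactly 3 bytes, so encode_component ORs each value into the right bit
// offsets and decode_component reassembles it. The code buffer must start out
// zeroed because encoding only ORs bits in. Helper names are illustrative.
#include <cassert>
#include <cstdint>
#include <cstring>

static void encode6(int bits, uint8_t* code, int i) { // bits in [0, 63]
    code += (i >> 2) * 3;
    switch (i & 3) {
        case 0: code[0] |= bits; break;
        case 1: code[0] |= bits << 6; code[1] |= bits >> 2; break;
        case 2: code[1] |= bits << 4; code[2] |= bits >> 4; break;
        case 3: code[2] |= bits << 2; break;
    }
}

static int decode6(const uint8_t* code, int i) {
    code += (i >> 2) * 3;
    switch (i & 3) {
        case 0: return code[0] & 0x3f;
        case 1: return (code[0] >> 6) | ((code[1] & 0xf) << 2);
        case 2: return (code[1] >> 4) | ((code[2] & 3) << 4);
        case 3: return code[2] >> 2;
    }
    return 0;
}

int main() {
    uint8_t code[3];
    std::memset(code, 0, sizeof(code)); // required: encoding only ORs bits in
    int values[4] = {5, 42, 63, 17};
    for (int i = 0; i < 4; i++) encode6(values[i], code, i);
    for (int i = 0; i < 4; i++) assert(decode6(code, i) == values[i]);
    return 0;
}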
becomes used oned day. - // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); - // const uint32_t* data32 = (const uint32_t*)data16; - // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); - // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); - // const __m128i i8 = _mm_set1_epi64x(vext); - // const __m256i i32 = _mm256_cvtepi8_epi32(i8); - // const __m256 f8 = _mm256_cvtepi32_ps(i32); - // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - // return _mm256_fmadd_ps(f8, one_255, half_one_255); - - __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); - __m256 f8 = _mm256_cvtepi32_ps(i8); - // this could also be done with bit manipulations but it is - // not obviously faster - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); - } - -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; - } -#endif -}; - -/******************************************************************* - * Quantizer: normalizes scalar vector components, then passes them - * through a codec - *******************************************************************/ - -enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; - -template -struct QuantizerTemplate {}; - -template -struct QuantizerTemplate - : ScalarQuantizer::SQuantizer { - const size_t d; - const float vmin, vdiff; - - QuantizerTemplate(size_t d, const std::vector& trained) - : d(d), vmin(trained[0]), vdiff(trained[1]) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - float xi = 0; - if (vdiff != 0) { - xi = (x[i] - vmin) / vdiff; - if (xi < 0) { - xi = 0; - } - if (xi > 1.0) { - xi = 1.0; - } - } - Codec::encode_component(xi, code, i); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); - x[i] = vmin + xi * vdiff; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); - return vmin + xi * vdiff; - } -}; - -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i); - return _mm512_fmadd_ps( - xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin)); - } -}; - -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin)); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : 
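// Standalone sketch of what the UNIFORM QuantizerTemplate above does per
// component, with trained = {vmin, vdiff} shared by all dimensions: normalize
// into [0, 1] with clamping, run it through a codec (8-bit here), and
// reconstruct as vmin + decoded * vdiff. Names are illustrative only.
#include <cstdint>
#include <cstdio>

struct UniformSQ8 {
    float vmin, vdiff;
    uint8_t encode(float x) const {
        float xi = vdiff != 0 ? (x - vmin) / vdiff : 0;
        if (xi < 0) xi = 0;
        if (xi > 1) xi = 1;
        return (uint8_t)(255 * xi);
    }
    float reconstruct(uint8_t c) const {
        return vmin + ((c + 0.5f) / 255.0f) * vdiff;
    }
};

int main() {
    UniformSQ8 q{-1.0f, 2.0f};                    // values trained to [-1, 1]
    float x = 0.3f;
    printf("%f -> %u -> %f\n", x, q.encode(x), q.reconstruct(q.encode(x)));
}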
QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - return {vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[0], - vdupq_n_f32(this->vdiff)), - vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[1], - vdupq_n_f32(this->vdiff))}; - } -}; - -#endif - -template -struct QuantizerTemplate - : ScalarQuantizer::SQuantizer { - const size_t d; - const float *vmin, *vdiff; - - QuantizerTemplate(size_t d, const std::vector& trained) - : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - float xi = 0; - if (vdiff[i] != 0) { - xi = (x[i] - vmin[i]) / vdiff[i]; - if (xi < 0) { - xi = 0; - } - if (xi > 1.0) { - xi = 1.0; - } - } - Codec::encode_component(xi, code, i); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); - x[i] = vmin[i] + xi * vdiff[i]; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); - return vmin[i] + xi * vdiff[i]; - } -}; - -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i); - return _mm512_fmadd_ps( - xi, - _mm512_loadu_ps(this->vdiff + i), - _mm512_loadu_ps(this->vmin + i)); - } -}; - -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, - _mm256_loadu_ps(this->vdiff + i), - _mm256_loadu_ps(this->vmin + i)); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - - float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); - float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); - - return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), - vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}; - } -}; - -#endif - -/******************************************************************* - * FP16 quantizer - *******************************************************************/ - -template -struct QuantizerFP16 {}; - -template <> -struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - ((uint16_t*)code)[i] = encode_fp16(x[i]); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < 
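// Standalone sketch of the NON_UNIFORM variant above: the trained vector
// stores d per-dimension minima followed by d per-dimension ranges, and each
// component gets its own affine map at reconstruction time. The helper and
// toy values below are illustrative only.
#include <cstddef>
#include <cstdint>
#include <vector>

static void reconstruct_non_uniform(
        const std::vector<float>& trained, // size 2*d: [vmin(0..d), vdiff(0..d)]
        const uint8_t* code,               // d 8-bit codes
        size_t d,
        float* out) {
    const float* vmin = trained.data();
    const float* vdiff = trained.data() + d;
    for (size_t i = 0; i < d; i++) {
        float xi = (code[i] + 0.5f) / 255.0f;  // Codec8bit-style decode
        out[i] = vmin[i] + xi * vdiff[i];
    }
}

int main() {
    size_t d = 2;
    std::vector<float> trained = {0.f, -1.f, /* vdiff: */ 1.f, 2.f};
    uint8_t code[2] = {0, 255};
    float out[2];
    reconstruct_non_uniform(trained, code, d, out); // ~{0.002, 1.004}
}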
d; i++) { - x[i] = decode_fp16(((uint16_t*)code)[i]); - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return decode_fp16(((uint16_t*)code)[i]); - } -}; - -#if defined(USE_AVX512_F16C) - -template <> -struct QuantizerFP16<16> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); - return _mm512_cvtph_ps(codei); - } -}; - -#endif - -#if defined(USE_F16C) - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - return _mm256_cvtph_ps(codei); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), - vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}; - } -}; -#endif - -/******************************************************************* - * BF16 quantizer - *******************************************************************/ - -template -struct QuantizerBF16 {}; - -template <> -struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - ((uint16_t*)code)[i] = encode_bf16(x[i]); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = decode_bf16(((uint16_t*)code)[i]); - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return decode_bf16(((uint16_t*)code)[i]); - } -}; - -#if defined(__AVX512F__) - -template <> -struct QuantizerBF16<16> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); - __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); - code_512i = _mm512_slli_epi32(code_512i, 16); - return _mm512_castsi512_ps(code_512i); - } -}; - -#elif defined(__AVX2__) - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); - code_256i = _mm256_slli_epi32(code_256i, 16); - return _mm256_castsi256_ps(code_256i); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - 
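// Standalone sketch of the bf16 idea used by QuantizerBF16 above: keep only
// the top 16 bits of an IEEE float (sign, exponent, 7 mantissa bits). This is
// the plain truncating variant; the library's encode_bf16 helper may round
// instead, so treat this as an illustration rather than its exact definition.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t bf16_encode(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return (uint16_t)(u >> 16);           // drop the low 16 mantissa bits
}

static float bf16_decode(uint16_t b) {
    uint32_t u = (uint32_t)b << 16;       // low mantissa bits come back as 0
    float x;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main() {
    float x = 3.14159265f;
    printf("%f -> %f\n", x, bf16_decode(bf16_encode(x))); // ~3.140625
}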
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), - vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}; - } -}; -#endif - -/******************************************************************* - * 8bit_direct quantizer - *******************************************************************/ - -template -struct Quantizer8bitDirect {}; - -template <> -struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - Quantizer8bitDirect(size_t d, const std::vector& /* unused */) - : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - code[i] = (uint8_t)x[i]; - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = code[i]; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return code[i]; - } -}; - -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - return _mm512_cvtepi32_ps(y16); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - return _mm256_cvtepi32_ps(y8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - // convert uint16 -> uint32 -> fp32 - return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))}; - } -}; - -#endif - -/******************************************************************* - * 8bit_direct_signed quantizer - *******************************************************************/ - -template -struct Quantizer8bitDirectSigned {}; - -template <> -struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { - const size_t d; - - Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) - : d(d) {} - - void encode_vector(const float* x, uint8_t* code) const final { - for (size_t i = 0; i < d; i++) { - code[i] = (uint8_t)(x[i] + 128); - } - } - - void decode_vector(const uint8_t* code, float* x) const final { - for (size_t i = 0; i < d; i++) { - x[i] = code[i] - 128; - } - } - - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return code[i] - 128; - } -}; - -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : 
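// Standalone sketch of the 8bit_direct variants above: "direct" stores the
// component value itself as a byte, and the "signed" flavor shifts by 128 so
// that inputs in [-128, 127] fit into a uint8_t and round-trip exactly.
// Illustrative helpers only.
#include <cassert>
#include <cstdint>

static uint8_t encode_direct_signed(float x)   { return (uint8_t)(x + 128); }
static float   decode_direct_signed(uint8_t c) { return c - 128; }

int main() {
    for (int v = -128; v <= 127; v++) {
        uint8_t c = encode_direct_signed((float)v);
        assert(decode_direct_signed(c) == (float)v);
    }
    return 0;
}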
Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m512 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - __m512i c16 = _mm512_set1_epi32(128); - __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes - return _mm512_cvtepi32_ps(z16); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - __m256i c8 = _mm256_set1_epi32(128); - __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes - return _mm256_cvtepi32_ps(z8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - float32x4_t z8_0 = vcvtq_f32_u32( - vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 - float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); - - // subtract 128 to convert into signed numbers - return {vsubq_f32(z8_0, vmovq_n_f32(128.0)), - vsubq_f32(z8_1, vmovq_n_f32(128.0))}; - } -}; - -#endif - -template -ScalarQuantizer::SQuantizer* select_quantizer_1( - QuantizerType qtype, - size_t d, - const std::vector& trained) { - switch (qtype) { - case ScalarQuantizer::QT_8bit: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_6bit: - return new QuantizerTemplate< - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_8bit_uniform: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit_uniform: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_fp16: - return new QuantizerFP16(d, trained); - case ScalarQuantizer::QT_bf16: - return new QuantizerBF16(d, trained); - case ScalarQuantizer::QT_8bit_direct: - return new Quantizer8bitDirect(d, trained); - case ScalarQuantizer::QT_8bit_direct_signed: - return new Quantizer8bitDirectSigned(d, trained); - } - FAISS_THROW_MSG("unknown qtype"); -} - -/******************************************************************* - * Quantizer range training - */ - -static float sqr(float x) { - return x * x; -} - -void train_Uniform( - RangeStat rs, - float rs_arg, - idx_t n, - int k, - const float* x, - std::vector& trained) { - trained.resize(2); - float& vmin = trained[0]; - float& vmax = trained[1]; - - if (rs == ScalarQuantizer::RS_minmax) { - vmin = HUGE_VAL; - vmax = 
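// Standalone sketch of the RS_minmax branch of train_Uniform above: take the
// data min/max, widen both ends by rs_arg * (max - min), and return the result
// as {vmin, vdiff}, since the quantizers consume a (min, range) pair.
// Illustrative helper, not the library function.
#include <cstddef>
#include <vector>

static std::vector<float> train_uniform_minmax(
        const float* x, size_t n, float rs_arg) {
    float vmin = x[0], vmax = x[0];
    for (size_t i = 1; i < n; i++) {
        if (x[i] < vmin) vmin = x[i];
        if (x[i] > vmax) vmax = x[i];
    }
    float vexp = (vmax - vmin) * rs_arg;   // optional safety margin
    vmin -= vexp;
    vmax += vexp;
    return {vmin, vmax - vmin};            // {vmin, vdiff}
}

int main() {
    float data[3] = {-0.5f, 0.25f, 2.0f};
    std::vector<float> trained = train_uniform_minmax(data, 3, 0.0f); // {-0.5, 2.5}
    return trained.size() == 2 ? 0 : 1;
}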
-HUGE_VAL; - for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) { - vmin = x[i]; - } - if (x[i] > vmax) { - vmax = x[i]; - } - } - float vexp = (vmax - vmin) * rs_arg; - vmin -= vexp; - vmax += vexp; - } else if (rs == ScalarQuantizer::RS_meanstd) { - double sum = 0, sum2 = 0; - for (size_t i = 0; i < n; i++) { - sum += x[i]; - sum2 += x[i] * x[i]; - } - float mean = sum / n; - float var = sum2 / n - mean * mean; - float std = var <= 0 ? 1.0 : sqrt(var); - - vmin = mean - std * rs_arg; - vmax = mean + std * rs_arg; - } else if (rs == ScalarQuantizer::RS_quantiles) { - std::vector x_copy(n); - memcpy(x_copy.data(), x, n * sizeof(*x)); - // TODO just do a quickselect - std::sort(x_copy.begin(), x_copy.end()); - int o = int(rs_arg * n); - if (o < 0) { - o = 0; - } - if (o > n - o) { - o = n / 2; - } - vmin = x_copy[o]; - vmax = x_copy[n - 1 - o]; - - } else if (rs == ScalarQuantizer::RS_optim) { - float a, b; - float sx = 0; - { - vmin = HUGE_VAL, vmax = -HUGE_VAL; - for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) { - vmin = x[i]; - } - if (x[i] > vmax) { - vmax = x[i]; - } - sx += x[i]; - } - b = vmin; - a = (vmax - vmin) / (k - 1); - } - int verbose = false; - int niter = 2000; - float last_err = -1; - int iter_last_err = 0; - for (int it = 0; it < niter; it++) { - float sn = 0, sn2 = 0, sxn = 0, err1 = 0; - - for (idx_t i = 0; i < n; i++) { - float xi = x[i]; - float ni = floor((xi - b) / a + 0.5); - if (ni < 0) { - ni = 0; - } - if (ni >= k) { - ni = k - 1; - } - err1 += sqr(xi - (ni * a + b)); - sn += ni; - sn2 += ni * ni; - sxn += ni * xi; - } - - if (err1 == last_err) { - iter_last_err++; - if (iter_last_err == 16) { - break; - } - } else { - last_err = err1; - iter_last_err = 0; - } - - float det = sqr(sn) - sn2 * n; - - b = (sn * sxn - sn2 * sx) / det; - a = (sn * sx - n * sxn) / det; - if (verbose) { - printf("it %d, err1=%g \r", it, err1); - fflush(stdout); - } - } - if (verbose) { - printf("\n"); - } - - vmin = b; - vmax = b + a * (k - 1); - - } else { - FAISS_THROW_MSG("Invalid qtype"); - } - vmax -= vmin; -} - -void train_NonUniform( - RangeStat rs, - float rs_arg, - idx_t n, - int d, - int k, - const float* x, - std::vector& trained) { - trained.resize(2 * d); - float* vmin = trained.data(); - float* vmax = trained.data() + d; - if (rs == ScalarQuantizer::RS_minmax) { - memcpy(vmin, x, sizeof(*x) * d); - memcpy(vmax, x, sizeof(*x) * d); - for (size_t i = 1; i < n; i++) { - const float* xi = x + i * d; - for (size_t j = 0; j < d; j++) { - if (xi[j] < vmin[j]) { - vmin[j] = xi[j]; - } - if (xi[j] > vmax[j]) { - vmax[j] = xi[j]; - } - } - } - float* vdiff = vmax; - for (size_t j = 0; j < d; j++) { - float vexp = (vmax[j] - vmin[j]) * rs_arg; - vmin[j] -= vexp; - vmax[j] += vexp; - vdiff[j] = vmax[j] - vmin[j]; - } - } else { - // transpose - std::vector xt(n * d); - for (size_t i = 1; i < n; i++) { - const float* xi = x + i * d; - for (size_t j = 0; j < d; j++) { - xt[j * n + i] = xi[j]; - } - } - std::vector trained_d(2); -#pragma omp parallel for - for (int j = 0; j < d; j++) { - train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d); - vmin[j] = trained_d[0]; - vmax[j] = trained_d[1]; - } - } -} - -/******************************************************************* - * Similarity: gets vector components and computes a similarity wrt. a - * query vector stored in the object. The data fields just encapsulate - * an accumulator. 
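// Standalone sketch of the RS_meanstd branch above: instead of the raw
// min/max, the range is centered on the mean and spans rs_arg standard
// deviations on each side, which is less sensitive to outliers. Illustrative
// helper only; the returned pair follows the same {vmin, vdiff} convention.
#include <cmath>
#include <cstddef>
#include <vector>

static std::vector<float> train_uniform_meanstd(
        const float* x, size_t n, float rs_arg) {
    double sum = 0, sum2 = 0;
    for (size_t i = 0; i < n; i++) {
        sum += x[i];
        sum2 += double(x[i]) * x[i];
    }
    float mean = float(sum / n);
    float var = float(sum2 / n - double(mean) * mean);
    float std_ = var <= 0 ? 1.0f : std::sqrt(var);
    float vmin = mean - std_ * rs_arg;
    float vmax = mean + std_ * rs_arg;
    return {vmin, vmax - vmin};
}

int main() {
    float data[4] = {0.f, 1.f, 2.f, 3.f};
    std::vector<float> t = train_uniform_meanstd(data, 4, 2.0f); // ~{-0.74, 4.47}
    return t.size() == 2 ? 0 : 1;
}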
- */ - -template -struct SimilarityL2 {}; - -template <> -struct SimilarityL2<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - - /******* scalar accumulator *******/ - - float accu; - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - float tmp = *yi++ - x; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - float tmp = x1 - x2; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityL2<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - __m512 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16 = _mm512_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - __m512 tmp = _mm512_sub_ps(yiv, x); - accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); - } - - FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) { - __m512 tmp = _mm512_sub_ps(y_2, x); - accu16 = _mm512_fmadd_ps(tmp, tmp, accu16); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - __m256 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - __m256 tmp = _mm256_sub_ps(yiv, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) { - __m256 tmp = _mm256_sub_ps(y_2, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; - -#endif - -#ifdef USE_NEON -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - explicit SimilarityL2(const float* y) : y(y) {} - float32x4x2_t accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x, - float32x4x2_t y) { - float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); - - float32x4_t 
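// Scalar sketch of what the SIMD Similarity specializations above compute:
// each of the 8 (or 16) lanes keeps its own partial sum during
// add_8_components(), and result_8() / result_16() finish with a horizontal
// reduction over the lanes. Illustrative emulation only.
#include <cstdio>

int main() {
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float y[8] = {1, 1, 1, 1, 1, 1, 1, 1};
    float lanes[8] = {0};                  // accu8: one partial sum per lane
    for (int i = 0; i < 8; i++) {          // add_8_components on one register
        float diff = y[i] - x[i];
        lanes[i] += diff * diff;
    }
    float result = 0;                      // result_8: horizontal add
    for (int i = 0; i < 8; i++) result += lanes[i];
    printf("%f\n", result);                // squared L2 distance = 140
}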
accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE float result_8() { - float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]); - float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]); - - float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); - float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); - return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); - } -}; -#endif - -template -struct SimilarityIP {}; - -template <> -struct SimilarityIP<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - accu += *yi++ * x; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - accu += x1 * x2; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityIP<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - __m512 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16 = _mm512_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - accu16 = _mm512_fmadd_ps(yiv, x, accu16); - } - - FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) { - accu16 = _mm512_fmadd_ps(x1, x2, accu16); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - __m256 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); - yi = y; - } +#include +#include - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - accu8 = _mm256_fmadd_ps(yiv, x, accu8); - } +#include +#include +#include - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) { - accu8 = _mm256_fmadd_ps(x1, x2, accu8); - } +#include - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; +#ifdef __SSE__ +#include #endif -#ifdef USE_NEON - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - explicit SimilarityIP(const float* y) : y(y) {} - float32x4x2_t accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); - float32x4_t 
accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x1, - float32x4x2_t x2) { - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE float result_8() { - float32x4x2_t sum = { - vpaddq_f32(accu8.val[0], accu8.val[0]), - vpaddq_f32(accu8.val[1], accu8.val[1])}; - - float32x4x2_t sum2 = { - vpaddq_f32(sum.val[0], sum.val[0]), - vpaddq_f32(sum.val[1], sum.val[1])}; - return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); - } -}; -#endif +#include +#include +#include +#include +#include +#include +#include /******************************************************************* - * DistanceComputer: combines a similarity and a quantizer to do - * code-to-vector or code-to-code comparisons - *******************************************************************/ - -template -struct DCTemplate : SQDistanceComputer {}; - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float xi = quant.reconstruct_component(code, i); - sim.add_component(xi); - } - return sim.result(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float x1 = quant.reconstruct_component(code1, i); - float x2 = quant.reconstruct_component(code2, i); - sim.add_component_2(x1, x2); - } - return sim.result(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; - -#if defined(USE_AVX512_F16C) - -template -struct DCTemplate - : SQDistanceComputer { // Update to handle 16 lanes - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - __m512 xi = quant.reconstruct_16_components(code, i); - sim.add_16_components(xi); - } - return sim.result_16(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - __m512 x1 = quant.reconstruct_16_components(code1, i); - __m512 x2 = quant.reconstruct_16_components(code2, i); - sim.add_16_components_2(x1, x2); - } - return sim.result_16(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; - -#elif defined(USE_F16C) - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - - float 
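// Standalone sketch of the DCTemplate composition above: a distance computer
// is just a loop that asks the quantizer for one reconstructed component at a
// time and feeds it to the similarity accumulator, so no decoded vector is
// ever materialized. The toy scalar types below are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct ToyQuantizer {                      // 8bit_direct-style reconstruction
    size_t d;
    float reconstruct_component(const uint8_t* code, size_t i) const {
        return code[i];
    }
};

struct ToySimilarityL2 {                   // begin / add_component / result
    const float* y;
    const float* yi;
    float accu;
    void begin() { accu = 0; yi = y; }
    void add_component(float x) { float t = *yi++ - x; accu += t * t; }
    float result() const { return accu; }
};

static float query_to_code(
        const ToyQuantizer& quant, const float* query, const uint8_t* code) {
    ToySimilarityL2 sim{query, nullptr, 0};
    sim.begin();
    for (size_t i = 0; i < quant.d; i++) {
        sim.add_component(quant.reconstruct_component(code, i));
    }
    return sim.result();
}

int main() {
    ToyQuantizer quant{3};
    float query[3] = {1, 2, 3};
    uint8_t code[3] = {1, 2, 4};
    printf("%f\n", query_to_code(quant, query, code)); // 1.0
}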
compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - __m256 xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); - } - return sim.result_8(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - __m256 x1 = quant.reconstruct_8_components(code1, i); - __m256 x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); - } - return sim.result_8(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; + * ScalarQuantizer implementation + * + * The main source of complexity is to support combinations of 4 + * variants without incurring runtime tests or virtual function calls: + * + * - 4 / 6 / 8 bits per code component + * - uniform / non-uniform + * - IP / L2 distance search + * - scalar / SIMD distance computation + * + * The appropriate Quantizer object is returned via select_quantizer + * that hides the template mess. + ********************************************************************/ +#if defined(__AVX512F__) && defined(__F16C__) +#define USE_AVX512_F16C +#elif defined(__AVX2__) +#ifdef __F16C__ +#define USE_F16C +#else +#warning \ + "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" +#endif #endif -#ifdef USE_NEON - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); - } - return sim.result_8(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t x1 = quant.reconstruct_8_components(code1, i); - float32x4x2_t x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); - } - return sim.result_8(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; +#if defined(__aarch64__) +#if defined(__GNUC__) && __GNUC__ < 8 +#warning \ + "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" +#else +#define USE_NEON +#endif #endif /******************************************************************* - * DistanceComputerByte: computes distances in the integer domain - *******************************************************************/ - -template -struct DistanceComputerByte : SQDistanceComputer {}; - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - int accu = 
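// Standalone sketch of the DistanceComputerByte idea above: for 8bit_direct
// codes the query is rounded to bytes once (set_query), after which both L2
// and inner-product distances are accumulated entirely in integer arithmetic;
// the SIMD versions do the same work with _mm*_madd_epi16. Toy helper names
// are illustrative only.
#include <cstdint>
#include <cstdio>
#include <vector>

static int byte_l2_distance(const uint8_t* a, const uint8_t* b, int d) {
    int accu = 0;
    for (int i = 0; i < d; i++) {
        int diff = int(a[i]) - int(b[i]);
        accu += diff * diff;
    }
    return accu;
}

int main() {
    float query[4] = {10.2f, 20.7f, 0.f, 255.f};
    std::vector<uint8_t> q8(4);
    for (int i = 0; i < 4; i++) q8[i] = (uint8_t)query[i]; // truncating, as in set_query
    uint8_t code[4] = {10, 21, 1, 250};
    printf("%d\n", byte_l2_distance(q8.data(), code, 4)); // 0 + 1 + 1 + 25 = 27
}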
0; - for (int i = 0; i < d; i++) { - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - accu += int(code1[i]) * code2[i]; - } else { - int diff = int(code1[i]) - code2[i]; - accu += diff * diff; - } - } - return accu; - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#if defined(__AVX512F__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - __m512i accu = _mm512_setzero_si512(); - for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time - __m512i c1 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code1 + i))); - __m512i c2 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code2 + i))); - __m512i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm512_madd_epi16(c1, c2); - } else { - __m512i diff = _mm512_sub_epi16(c1, c2); - prod32 = _mm512_madd_epi16(diff, diff); - } - accu = _mm512_add_epi32(accu, prod32); - } - // Horizontally add elements of accu - return _mm512_reduce_add_epi32(accu); - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#elif defined(__AVX2__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - // __m256i accu = _mm256_setzero_ps (); - __m256i accu = _mm256_setzero_si256(); - for (int i = 0; i < d; i += 16) { - // load 16 bytes, convert to 16 uint16_t - __m256i c1 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code1 + i))); - __m256i c2 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code2 + i))); - __m256i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm256_madd_epi16(c1, c2); - } else { - __m256i diff = _mm256_sub_epi16(c1, c2); - prod32 = _mm256_madd_epi16(diff, diff); - } - accu = _mm256_add_epi32(accu, prod32); - } - __m128i sum = _mm256_extractf128_si256(accu, 0); - sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); - sum = _mm_hadd_epi32(sum, sum); - sum = _mm_hadd_epi32(sum, sum); - return _mm_cvtsi128_si32(sum); - } + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). 
+ */ - void set_query(const float* x) final { - /* - for (int i = 0; i < d; i += 8) { - __m256 xi = _mm256_loadu_ps (x + i); - __m256i ci = _mm256_cvtps_epi32(xi); - */ - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } +#include - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } +#include - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; +/******************************************************************* + * Similarity: gets vector components and computes a similarity wrt. a + * query vector stored in the object. The data fields just encapsulate + * an accumulator. + */ -#endif +#include -#ifdef USE_NEON +/******************************************************************* + * DistanceComputer: combines a similarity and a quantizer to do + * code-to-vector or code-to-code comparisons + *******************************************************************/ -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - int accu = 0; - for (int i = 0; i < d; i++) { - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - accu += int(code1[i]) * code2[i]; - } else { - int diff = int(code1[i]) - code2[i]; - accu += diff * diff; - } - } - return accu; - } +#include - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } +namespace faiss { - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } +namespace scalar_quantizer { - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } +typedef ScalarQuantizer::QuantizerType QuantizerType; +typedef ScalarQuantizer::RangeStat RangeStat; +using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); +template +ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained) { + switch (qtype) { + case ScalarQuantizer::QT_8bit: + return new QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); + case ScalarQuantizer::QT_6bit: + return new QuantizerTemplate< + Codec6bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); + case ScalarQuantizer::QT_4bit: + return new QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); + case ScalarQuantizer::QT_8bit_uniform: + return new QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>(d, trained); + case ScalarQuantizer::QT_4bit_uniform: + return new QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>(d, trained); + case ScalarQuantizer::QT_fp16: + return new QuantizerFP16(d, trained); 
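// Generic sketch of the dispatch pattern used by select_quantizer_1 above: a
// single runtime switch on the quantizer type picks a fully templated
// implementation (codec x scaling x SIMD width), so the hot loops themselves
// contain no virtual calls or runtime tests. All types below are toys chosen
// for illustration, not the faiss classes.
#include <cstdint>
#include <memory>
#include <stdexcept>

struct SQuantizerBase {
    virtual void encode(const float* x, uint8_t* code) const = 0;
    virtual ~SQuantizerBase() = default;
};

enum class QType { QT_8bit, QT_4bit };

struct Codec8 {};
struct Codec4 {};

template <int SIMDWIDTH, class Codec>
struct QuantizerImpl : SQuantizerBase {
    void encode(const float*, uint8_t*) const override { /* codec-specific */ }
};

template <int SIMDWIDTH>
std::unique_ptr<SQuantizerBase> select_quantizer(QType qtype) {
    switch (qtype) {
        case QType::QT_8bit:
            return std::make_unique<QuantizerImpl<SIMDWIDTH, Codec8>>();
        case QType::QT_4bit:
            return std::make_unique<QuantizerImpl<SIMDWIDTH, Codec4>>();
    }
    throw std::runtime_error("unknown qtype");
}

int main() {
    auto q = select_quantizer<8>(QType::QT_8bit); // width fixed at compile time
    float x[4] = {0, 0, 0, 0};
    uint8_t code[4];
    q->encode(x, code);
}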
+ case ScalarQuantizer::QT_bf16: + return new QuantizerBF16(d, trained); + case ScalarQuantizer::QT_8bit_direct: + return new Quantizer8bitDirect(d, trained); + case ScalarQuantizer::QT_8bit_direct_signed: + return new Quantizer8bitDirectSigned(d, trained); } -}; - -#endif + FAISS_THROW_MSG("unknown qtype"); +} /******************************************************************* * select_distance_computer: runtime selection of template @@ -1974,7 +235,9 @@ SQDistanceComputer* select_distance_computer( return nullptr; } -} // anonymous namespace +} // namespace scalar_quantizer + +using namespace scalar_quantizer; /******************************************************************* * ScalarQuantizer implementation diff --git a/faiss/impl/ScalarQuantizer.h b/faiss/impl/ScalarQuantizer.h index c1f4f98f63..279938443a 100644 --- a/faiss/impl/ScalarQuantizer.h +++ b/faiss/impl/ScalarQuantizer.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #pragma once #include diff --git a/faiss/impl/scalar_quantizer/codecs.h b/faiss/impl/scalar_quantizer/codecs.h new file mode 100644 index 0000000000..31c75bc632 --- /dev/null +++ b/faiss/impl/scalar_quantizer/codecs.h @@ -0,0 +1,305 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { + +namespace scalar_quantizer { + +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ + +struct Codec8bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + code[i] = (int)(255 * x); + } + + static FAISS_ALWAYS_INLINE float decode_component( + const uint8_t* code, + int i) { + return (code[i] + 0.5f) / 255.0f; + } + +#if defined(__AVX512F__) + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); + const __m512i i32 = _mm512_cvtepu8_epi32(c16); + const __m512 f16 = _mm512_cvtepi32_ps(i32); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +#elif defined(__AVX2__) + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + const uint64_t c8 = *(uint64_t*)(code + i); + + const __m128i i8 = _mm_set1_epi64x(c8); + const __m256i i32 = _mm256_cvtepu8_epi32(i8); + const __m256 f8 = _mm256_cvtepi32_ps(i32); + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } +#endif + +#ifdef USE_NEON + static FAISS_ALWAYS_INLINE decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t{res1, res2}); + } +#endif +}; + +struct Codec4bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2); + } + + static FAISS_ALWAYS_INLINE float 
decode_component( + const uint8_t* code, + int i) { + return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; + } + +#if defined(__AVX512F__) + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + uint64_t c8 = *(uint64_t*)(code + (i >> 1)); + uint64_t mask = 0x0f0f0f0f0f0f0f0f; + uint64_t c8ev = c8 & mask; + uint64_t c8od = (c8 >> 4) & mask; + + __m128i c16 = + _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); + __m256i c8lo = _mm256_cvtepu8_epi32(c16); + __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); + __m512i i16 = _mm512_castsi256_si512(c8lo); + i16 = _mm512_inserti32x8(i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps(i16); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +#elif defined(__AVX2__) + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + uint32_t c4 = *(uint32_t*)(code + (i >> 1)); + uint32_t mask = 0x0f0f0f0f; + uint32_t c4ev = c4 & mask; + uint32_t c4od = (c4 >> 4) & mask; + + // the 8 lower bytes of c8 contain the values + __m128i c8 = + _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); + __m128i c4lo = _mm_cvtepu8_epi32(c8); + __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); + __m256i i8 = _mm256_castsi128_si256(c4lo); + i8 = _mm256_insertf128_si256(i8, c4hi, 1); + __m256 f8 = _mm256_cvtepi32_ps(i8); + __m256 half = _mm256_set1_ps(0.5f); + f8 = _mm256_add_ps(f8, half); + __m256 one_255 = _mm256_set1_ps(1.f / 15.f); + return simd8float32(_mm256_mul_ps(f8, one_255)); + } +#endif + +#ifdef USE_NEON + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32({res1, res2}); + } +#endif +}; + +struct Codec6bit { + static FAISS_ALWAYS_INLINE void encode_component( + float x, + uint8_t* code, + int i) { + int bits = (int)(x * 63.0); + code += (i >> 2) * 3; + switch (i & 3) { + case 0: + code[0] |= bits; + break; + case 1: + code[0] |= bits << 6; + code[1] |= bits >> 2; + break; + case 2: + code[1] |= bits << 4; + code[2] |= bits >> 4; + break; + case 3: + code[2] |= bits << 2; + break; + } + } + + static FAISS_ALWAYS_INLINE float decode_component( + const uint8_t* code, + int i) { + uint8_t bits; + code += (i >> 2) * 3; + switch (i & 3) { + case 0: + bits = code[0] & 0x3f; + break; + case 1: + bits = code[0] >> 6; + bits |= (code[1] & 0xf) << 2; + break; + case 2: + bits = code[1] >> 4; + bits |= (code[2] & 3) << 4; + break; + case 3: + bits = code[2] >> 2; + break; + } + return (bits + 0.5f) / 63.0f; + } + +#if defined(__AVX512F__) + + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + // pure AVX512 implementation (not necessarily the fastest). 
+ // see: + // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h + + // clang-format off + + // 16 components, 16x6 bit=12 bytes + const __m128i bit_6v = + _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); + const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); + + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F + // 00 01 02 03 + const __m256i shuffle_mask = _mm256_setr_epi16( + 0xFF00, 0x0100, 0x0201, 0xFF02, + 0xFF03, 0x0403, 0x0504, 0xFF05, + 0xFF06, 0x0706, 0x0807, 0xFF08, + 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); + const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); + + // 0: xxxxxxxx xx543210 + // 1: xxxx5432 10xxxxxx + // 2: xxxxxx54 3210xxxx + // 3: xxxxxxxx 543210xx + const __m256i shift_right_v = _mm256_setr_epi16( + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U); + __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); + + // remove unneeded bits + shuffled_shifted = + _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); + + // scale + const __m512 f8 = + _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); + return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255)); + + // clang-format on + } + +#elif defined(__AVX2__) + + /* Load 6 bytes that represent 8 6-bit values, return them as a + * 8*32 bit vector register */ + static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { + const __m128i perm = _mm_set_epi8( + -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); + const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); + + // load 6 bytes + __m128i c1 = + _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); + + // put in 8 * 32 bits + __m128i c2 = _mm_shuffle_epi8(c1, perm); + __m256i c3 = _mm256_cvtepi16_epi32(c2); + + // shift and mask out useless bits + __m256i c4 = _mm256_srlv_epi32(c3, shifts); + __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); + return c5; + } + + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here + // // for the reference, maybe, it becomes used oned day. 
+ // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); + // const uint32_t* data32 = (const uint32_t*)data16; + // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); + // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); + // const __m128i i8 = _mm_set1_epi64x(vext); + // const __m256i i32 = _mm256_cvtepi8_epi32(i8); + // const __m256 f8 = _mm256_cvtepi32_ps(i32); + // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + // return _mm256_fmadd_ps(f8, one_255, half_one_255); + + __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); + __m256 f8 = _mm256_cvtepi32_ps(i8); + // this could also be done with bit manipulations but it is + // not obviously faster + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } + +#endif + +#ifdef USE_NEON + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t({res1, res2})); + } +#endif +}; + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/distance_computers.h b/faiss/impl/scalar_quantizer/distance_computers.h new file mode 100644 index 0000000000..96de493204 --- /dev/null +++ b/faiss/impl/scalar_quantizer/distance_computers.h @@ -0,0 +1,381 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace faiss { + +namespace scalar_quantizer { + +using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; + +template +struct DCTemplate : SQDistanceComputer {}; + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float xi = quant.reconstruct_component(code, i); + sim.add_component(xi); + } + return sim.result(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float x1 = quant.reconstruct_component(code1, i); + float x2 = quant.reconstruct_component(code2, i); + sim.add_component_2(x1, x2); + } + return sim.result(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +#if defined(USE_AVX512_F16C) + +template +struct DCTemplate + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 x1 = quant.reconstruct_16_components(code1, i); + simd16float32 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +#elif defined(USE_F16C) || defined(USE_NEON) + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float 
query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +#endif + +/******************************************************************* + * DistanceComputerByte: computes distances in the integer domain + *******************************************************************/ + +template +struct DistanceComputerByte : SQDistanceComputer {}; + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#if defined(__AVX512F__) + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + __m512i accu = _mm512_setzero_si512(); + for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time + __m512i c1 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32(accu, prod32); + } + // Horizontally add elements of accu + return _mm512_reduce_add_epi32(accu); + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#elif defined(__AVX2__) + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + // __m256i accu = _mm256_setzero_ps (); + __m256i accu = _mm256_setzero_si256(); + for (int i = 0; i < d; i += 16) { + // load 16 bytes, convert to 16 uint16_t + __m256i c1 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code1 + i))); + __m256i c2 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code2 + i))); + __m256i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm256_madd_epi16(c1, c2); + } else { + __m256i diff = 
_mm256_sub_epi16(c1, c2); + prod32 = _mm256_madd_epi16(diff, diff); + } + accu = _mm256_add_epi32(accu, prod32); + } + __m128i sum = _mm256_extractf128_si256(accu, 0); + sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); + } + + void set_query(const float* x) final { + /* + for (int i = 0; i < d; i += 8) { + __m256 xi = _mm256_loadu_ps (x + i); + __m256i ci = _mm256_cvtps_epi32(xi); + */ + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#endif + +#ifdef USE_NEON + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#endif + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/quantizers.h b/faiss/impl/scalar_quantizer/quantizers.h new file mode 100644 index 0000000000..a4abf058c6 --- /dev/null +++ b/faiss/impl/scalar_quantizer/quantizers.h @@ -0,0 +1,586 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace faiss { + +namespace scalar_quantizer { + +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ + +enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; + +template +struct QuantizerTemplate {}; + +template +struct QuantizerTemplate + : ScalarQuantizer::SQuantizer { + const size_t d; + const float vmin, vdiff; + + QuantizerTemplate(size_t d, const std::vector& trained) + : d(d), vmin(trained[0]), vdiff(trained[1]) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = 0; + if (vdiff != 0) { + xi = (x[i] - vmin) / vdiff; + if (xi < 0) { + xi = 0; + } + if (xi > 1.0) { + xi = 1.0; + } + } + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin + xi * vdiff; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + float xi = Codec::decode_component(code, i); + return vmin + xi * vdiff; + } +}; + +#if defined(__AVX512F__) + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i); + return simd16float32(_mm512_fmadd_ps( + xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin))); + } +}; + +#elif defined(__AVX2__) + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin))); + } +}; + +#endif + +#ifdef USE_NEON + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i); + return simd8float32(float32x4x2_t( + {vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[0], + vdupq_n_f32(this->vdiff)), + vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[1], + vdupq_n_f32(this->vdiff))})); + } +}; + +#endif + +template +struct QuantizerTemplate + : ScalarQuantizer::SQuantizer { + const size_t d; + const float *vmin, *vdiff; + + QuantizerTemplate(size_t d, const std::vector& trained) + : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = 0; + if (vdiff[i] != 0) { + xi = (x[i] - vmin[i]) / vdiff[i]; + if (xi < 0) { + xi = 0; + } + if (xi > 1.0) { + xi = 1.0; + } + } + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin[i] + xi * vdiff[i]; + } + } + + FAISS_ALWAYS_INLINE float 
reconstruct_component(const uint8_t* code, int i) + const { + float xi = Codec::decode_component(code, i); + return vmin[i] + xi * vdiff[i]; + } +}; + +#if defined(__AVX512F__) + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate< + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i).f; + return simd16float32(_mm512_fmadd_ps( + xi, + _mm512_loadu_ps(this->vdiff + i), + _mm512_loadu_ps(this->vmin + i))); + } +}; + +#elif defined(__AVX2__) + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate< + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, + _mm256_loadu_ps(this->vdiff + i), + _mm256_loadu_ps(this->vmin + i))); + } +}; + +#endif + +#ifdef USE_NEON + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate< + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i).data; + + float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); + float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); + + return simd8float32( + {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), + vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}); + } +}; + +#endif + +/******************************************************************* + * FP16 quantizer + *******************************************************************/ + +template +struct QuantizerFP16 {}; + +template <> +struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { + const size_t d; + + QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_fp16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_fp16(((uint16_t*)code)[i]); + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return decode_fp16(((uint16_t*)code)[i]); + } +}; + +#if defined(USE_AVX512_F16C) + +template <> +struct QuantizerFP16<16> : QuantizerFP16<1> { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + return simd16float32(_mm512_cvtph_ps(codei)); + } +}; + +#endif + +#if defined(USE_F16C) + +template <> +struct QuantizerFP16<8> : QuantizerFP16<1> { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + return simd8float32(_mm256_cvtph_ps(codei)); + } +}; + +#endif + +#ifdef USE_NEON + +template <> +struct 
QuantizerFP16<8> : QuantizerFP16<1> { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}); + } +}; +#endif + +/******************************************************************* + * BF16 quantizer + *******************************************************************/ + +template +struct QuantizerBF16 {}; + +template <> +struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { + const size_t d; + + QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_bf16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_bf16(((uint16_t*)code)[i]); + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return decode_bf16(((uint16_t*)code)[i]); + } +}; + +#if defined(__AVX512F__) + +template <> +struct QuantizerBF16<16> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); + code_512i = _mm512_slli_epi32(code_512i, 16); + return simd16float32(_mm512_castsi512_ps(code_512i)); + } +}; + +#elif defined(__AVX2__) + +template <> +struct QuantizerBF16<8> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); + code_256i = _mm256_slli_epi32(code_256i, 16); + return simd8float32(_mm256_castsi256_ps(code_256i)); + } +}; + +#endif + +#ifdef USE_NEON + +template <> +struct QuantizerBF16<8> : QuantizerBF16<1> { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); + } +}; +#endif + +/******************************************************************* + * 8bit_direct quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirect {}; + +template <> +struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { + const size_t d; + + Quantizer8bitDirect(size_t d, const std::vector& /* unused */) + : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + code[i] = (uint8_t)x[i]; + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = code[i]; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, 
int i) + const { + return code[i]; + } +}; + +#if defined(__AVX512F__) + +template <> +struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32 + } +}; + +#elif defined(__AVX2__) + +template <> +struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 + return simd8float32(_mm256_cvtepi32_ps(y8)); // 8 * float32 + } +}; + +#endif + +#ifdef USE_NEON + +template <> +struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + // convert uint16 -> uint32 -> fp32 + return simd8float32( + {vcvtq_f32_u32(vmovl_u16(y8_0)), + vcvtq_f32_u32(vmovl_u16(y8_1))}); + } +}; + +#endif + +/******************************************************************* + * 8bit_direct_signed quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirectSigned {}; + +template <> +struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { + const size_t d; + + Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) + : d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + code[i] = (uint8_t)(x[i] + 128); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = code[i] - 128; + } + } + + FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) + const { + return code[i] - 128; + } +}; + +#if defined(__AVX512F__) + +template <> +struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + __m512i c16 = _mm512_set1_epi32(128); + __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes + return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32 + } +}; + +#elif defined(__AVX2__) + +template <> +struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 
* int32 + __m256i c8 = _mm256_set1_epi32(128); + __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes + return simd8float32(_mm256_cvtepi32_ps(z8)); // 8 * float32 + } +}; + +#endif + +#ifdef USE_NEON + +template <> +struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + float32x4_t z8_0 = vcvtq_f32_u32( + vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 + float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); + + // subtract 128 to convert into signed numbers + return simd8float32( + {vsubq_f32(z8_0, vmovq_n_f32(128.0)), + vsubq_f32(z8_1, vmovq_n_f32(128.0))}); + } +}; + +#endif + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/similarities.h b/faiss/impl/scalar_quantizer/similarities.h new file mode 100644 index 0000000000..99e5b1c089 --- /dev/null +++ b/faiss/impl/scalar_quantizer/similarities.h @@ -0,0 +1,345 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { + +namespace scalar_quantizer { + +template +struct SimilarityL2 {}; + +template <> +struct SimilarityL2<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + + /******* scalar accumulator *******/ + + float accu; + + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_component(float x) { + float tmp = *yi++ - x; + accu += tmp * tmp; + } + + FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + float tmp = x1 - x2; + accu += tmp * tmp; + } + + FAISS_ALWAYS_INLINE float result() { + return accu; + } +}; + +#if defined(__AVX512F__) + +template <> +struct SimilarityL2<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + __m512 tmp = _mm512_sub_ps(yiv, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE void add_16_components_2( + simd16float32 x, + simd16float32 y_2) { + __m512 tmp = _mm512_sub_ps(y_2.f, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +#elif defined(__AVX2__) + +template <> +struct SimilarityL2<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) 
{ + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + __m256 tmp = _mm256_sub_ps(yiv, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y_2) { + __m256 tmp = _mm256_sub_ps(y_2.f, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; + +#endif + +#ifdef USE_NEON +template <> +struct SimilarityL2<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y) { + float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4_t sum_0 = vpaddq_f32(accu8.data.val[0], accu8.data.val[0]); + float32x4_t sum_1 = vpaddq_f32(accu8.data.val[1], accu8.data.val[1]); + + float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); + float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); + return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); + } +}; +#endif + +template +struct SimilarityIP {}; + +template <> +struct SimilarityIP<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_component(float x) { + accu += *yi++ * x; + } + + FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + accu += x1 * x2; + } + + FAISS_ALWAYS_INLINE float result() { + return accu; + } +}; + +#if defined(__AVX512F__) + +template <> +struct SimilarityIP<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + accu16.f = _mm512_fmadd_ps(yiv, x, accu16.f); + } + + FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) { + accu16.f = _mm512_fmadd_ps(x1, x2, accu16.f); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into 
_mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +#elif defined(__AVX2__) + +template <> +struct SimilarityIP<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + accu8.f = _mm256_fmadd_ps(yiv, x.f, accu8.f); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x1, + simd8float32 x2) { + accu8.f = _mm256_fmadd_ps(x1.f, x2.f, accu8.f); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; +#endif + +#ifdef USE_NEON + +template <> +struct SimilarityIP<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + explicit SimilarityIP(const float* y) : y(y) {} + float32x4x2_t accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); + accu8 = {accu8_0, accu8_1}; + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + float32x4x2_t x1, + float32x4x2_t x2) { + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); + accu8 = {accu8_0, accu8_1}; + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4x2_t sum = { + vpaddq_f32(accu8.val[0], accu8.val[0]), + vpaddq_f32(accu8.val[1], accu8.val[1])}; + + float32x4x2_t sum2 = { + vpaddq_f32(sum.val[0], sum.val[0]), + vpaddq_f32(sum.val[1], sum.val[1])}; + return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); + } +}; +#endif + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/training.cpp b/faiss/impl/scalar_quantizer/training.cpp new file mode 100644 index 0000000000..23c51384fd --- /dev/null +++ b/faiss/impl/scalar_quantizer/training.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +namespace faiss { + +namespace scalar_quantizer { +/******************************************************************* + * Quantizer range training + */ + +static float sqr(float x) { + return x * x; +} + +void train_Uniform( + RangeStat rs, + float rs_arg, + idx_t n, + int k, + const float* x, + std::vector& trained) { + trained.resize(2); + float& vmin = trained[0]; + float& vmax = trained[1]; + + if (rs == ScalarQuantizer::RS_minmax) { + vmin = HUGE_VAL; + vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) + vmin = x[i]; + if (x[i] > vmax) + vmax = x[i]; + } + float vexp = (vmax - vmin) * rs_arg; + vmin -= vexp; + vmax += vexp; + } else if (rs == ScalarQuantizer::RS_meanstd) { + double sum = 0, sum2 = 0; + for (size_t i = 0; i < n; i++) { + sum += x[i]; + sum2 += x[i] * x[i]; + } + float mean = sum / n; + float var = sum2 / n - mean * mean; + float std = var <= 0 ? 1.0 : sqrt(var); + + vmin = mean - std * rs_arg; + vmax = mean + std * rs_arg; + } else if (rs == ScalarQuantizer::RS_quantiles) { + std::vector x_copy(n); + memcpy(x_copy.data(), x, n * sizeof(*x)); + // TODO just do a quickselect + std::sort(x_copy.begin(), x_copy.end()); + int o = int(rs_arg * n); + if (o < 0) + o = 0; + if (o > n - o) + o = n / 2; + vmin = x_copy[o]; + vmax = x_copy[n - 1 - o]; + + } else if (rs == ScalarQuantizer::RS_optim) { + float a, b; + float sx = 0; + { + vmin = HUGE_VAL, vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) + vmin = x[i]; + if (x[i] > vmax) + vmax = x[i]; + sx += x[i]; + } + b = vmin; + a = (vmax - vmin) / (k - 1); + } + int verbose = false; + int niter = 2000; + float last_err = -1; + int iter_last_err = 0; + for (int it = 0; it < niter; it++) { + float sn = 0, sn2 = 0, sxn = 0, err1 = 0; + + for (idx_t i = 0; i < n; i++) { + float xi = x[i]; + float ni = floor((xi - b) / a + 0.5); + if (ni < 0) + ni = 0; + if (ni >= k) + ni = k - 1; + err1 += sqr(xi - (ni * a + b)); + sn += ni; + sn2 += ni * ni; + sxn += ni * xi; + } + + if (err1 == last_err) { + iter_last_err++; + if (iter_last_err == 16) + break; + } else { + last_err = err1; + iter_last_err = 0; + } + + float det = sqr(sn) - sn2 * n; + + b = (sn * sxn - sn2 * sx) / det; + a = (sn * sx - n * sxn) / det; + if (verbose) { + printf("it %d, err1=%g \r", it, err1); + fflush(stdout); + } + } + if (verbose) + printf("\n"); + + vmin = b; + vmax = b + a * (k - 1); + + } else { + FAISS_THROW_MSG("Invalid qtype"); + } + vmax -= vmin; +} + +void train_NonUniform( + RangeStat rs, + float rs_arg, + idx_t n, + int d, + int k, + const float* x, + std::vector& trained) { + trained.resize(2 * d); + float* vmin = trained.data(); + float* vmax = trained.data() + d; + if (rs == ScalarQuantizer::RS_minmax) { + memcpy(vmin, x, sizeof(*x) * d); + memcpy(vmax, x, sizeof(*x) * d); + for (size_t i = 1; i < n; i++) { + const float* xi = x + i * d; + for (size_t j = 0; j < d; j++) { + if (xi[j] < vmin[j]) + vmin[j] = xi[j]; + if (xi[j] > vmax[j]) + vmax[j] = xi[j]; + } + } + float* vdiff = vmax; + for (size_t j = 0; j < d; j++) { + float vexp = (vmax[j] - vmin[j]) * rs_arg; + vmin[j] -= vexp; + vmax[j] += vexp; + vdiff[j] = vmax[j] - vmin[j]; + } + } else { + // transpose + std::vector xt(n * d); + for (size_t i = 1; i < n; i++) { + const float* xi = x + i * d; + for (size_t j = 0; j < d; j++) { + xt[j * n + i] = xi[j]; + } + } + std::vector trained_d(2); +#pragma omp parallel for + for (int j = 0; j < d; j++) { + train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, 
trained_d); + vmin[j] = trained_d[0]; + vmax[j] = trained_d[1]; + } + } +} + +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/training.h b/faiss/impl/scalar_quantizer/training.h new file mode 100644 index 0000000000..9eeb39b926 --- /dev/null +++ b/faiss/impl/scalar_quantizer/training.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +/******************************************************************* + * Quantizer range training for the scalar quantizer. This is independent of the + * searching code and needs not to be very optimized (scalar quantizer training + * is very efficient). + */ + +#include + +namespace faiss { + +namespace scalar_quantizer { + +using RangeStat = ScalarQuantizer::RangeStat; + +void train_Uniform( + RangeStat rs, + float rs_arg, + idx_t n, + int k, + const float* x, + std::vector& trained); + +void train_NonUniform( + RangeStat rs, + float rs_arg, + idx_t n, + int d, + int k, + const float* x, + std::vector& trained); +} // namespace scalar_quantizer + +} // namespace faiss diff --git a/faiss/utils/simd_impl/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h index 63b23f9b19..b1195c7e3c 100644 --- a/faiss/utils/simd_impl/simdlib_avx512.h +++ b/faiss/utils/simd_impl/simdlib_avx512.h @@ -14,7 +14,7 @@ #include -#include +#include namespace faiss { diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h index 98c38f7a0d..2b8bef4716 100644 --- a/faiss/utils/simdlib.h +++ b/faiss/utils/simdlib.h @@ -16,8 +16,8 @@ #if defined(__AVX512F__) -#include -#include +#include +#include #elif defined(__AVX2__) From e4aab93bd4d6e860d10f03205c2ac825371f6f71 Mon Sep 17 00:00:00 2001 From: Subhadeep Karan Date: Thu, 28 Aug 2025 01:14:45 -0700 Subject: [PATCH 5/5] Split ScalarQuantizer code into independent parts (#4557) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4557 Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4296 Splits the ScalarQuantizer code into parts so that the AVX2 and AVX512 can be compiled independently. 
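The resulting pattern is, roughly, one translation unit per SIMD level, each built with its own instruction-set flags, plus a thin runtime dispatcher in ScalarQuantizer.cpp that never touches wide intrinsics itself. A minimal sketch of that shape (only SIMDConfig, SIMDLevel and SQDistanceComputer are names taken from this patch; the file names, factory functions and compiler flags below are illustrative assumptions, not the actual code):

    // sq_generic.cpp : no extra flags
    // sq_avx2.cpp    : e.g. -mavx2 -mf16c
    // sq_avx512.cpp  : e.g. -mavx512f -mavx512bw -mavx512dq -mavx512vl
    // Each file defines exactly one factory, so only that file needs the flags.
    SQDistanceComputer* make_dc_generic(size_t d, const std::vector<float>& t);
    SQDistanceComputer* make_dc_avx2(size_t d, const std::vector<float>& t);
    SQDistanceComputer* make_dc_avx512(size_t d, const std::vector<float>& t);

    // Dispatcher: built without special flags, chooses at runtime from the
    // detected level (plus the vector-size constraints of the SIMD kernels).
    SQDistanceComputer* make_dc(size_t d, const std::vector<float>& t) {
        if (SIMDConfig::level == SIMDLevel::AVX512 && d % 16 == 0) {
            return make_dc_avx512(d, t);
        }
        if (SIMDConfig::level == SIMDLevel::AVX2 && d % 8 == 0) {
            return make_dc_avx2(d, t);
        }
        return make_dc_generic(d, t);
    }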
Reviewed By: mnorris11 Differential Revision: D73037185 --- faiss/impl/ScalarQuantizer.cpp | 594 ++---------------- .../code_distance/code_distance-avx512.cpp | 2 +- faiss/impl/scalar_quantizer/codecs.h | 222 +------ .../scalar_quantizer/distance_computers.h | 400 +++++------- faiss/impl/scalar_quantizer/impl-avx2.cpp | 437 +++++++++++++ faiss/impl/scalar_quantizer/impl-avx512.cpp | 409 ++++++++++++ faiss/impl/scalar_quantizer/impl-neon.cpp | 377 +++++++++++ faiss/impl/scalar_quantizer/quantizers.h | 441 +++---------- faiss/impl/scalar_quantizer/scanners.h | 356 +++++++++++ faiss/impl/scalar_quantizer/similarities.h | 345 ---------- faiss/utils/simd_impl/simdlib_avx512.h | 43 ++ 11 files changed, 1898 insertions(+), 1728 deletions(-) create mode 100644 faiss/impl/scalar_quantizer/impl-avx2.cpp create mode 100644 faiss/impl/scalar_quantizer/impl-avx512.cpp create mode 100644 faiss/impl/scalar_quantizer/impl-neon.cpp create mode 100644 faiss/impl/scalar_quantizer/scanners.h delete mode 100644 faiss/impl/scalar_quantizer/similarities.h diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index ada60c116d..17d08f2e8c 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -12,20 +12,14 @@ #include #include -#include - -#include - -#ifdef __SSE__ -#include -#endif +#include #include #include #include #include -#include -#include +#include +#include #include /******************************************************************* @@ -43,26 +37,6 @@ * that hides the template mess. ********************************************************************/ -#if defined(__AVX512F__) && defined(__F16C__) -#define USE_AVX512_F16C -#elif defined(__AVX2__) -#ifdef __F16C__ -#define USE_F16C -#else -#warning \ - "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" -#endif -#endif - -#if defined(__aarch64__) -#if defined(__GNUC__) && __GNUC__ < 8 -#warning \ - "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" -#else -#define USE_NEON -#endif -#endif - /******************************************************************* * Codec: converts between values in [0, 1] and an index in a code * array. The "i" parameter is the vector component index (not byte @@ -82,161 +56,19 @@ * Similarity: gets vector components and computes a similarity wrt. a * query vector stored in the object. The data fields just encapsulate * an accumulator. 
- */ - -#include - -/******************************************************************* * DistanceComputer: combines a similarity and a quantizer to do * code-to-vector or code-to-code comparisons *******************************************************************/ #include -namespace faiss { - -namespace scalar_quantizer { - -typedef ScalarQuantizer::QuantizerType QuantizerType; -typedef ScalarQuantizer::RangeStat RangeStat; -using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; - -template -ScalarQuantizer::SQuantizer* select_quantizer_1( - QuantizerType qtype, - size_t d, - const std::vector& trained) { - switch (qtype) { - case ScalarQuantizer::QT_8bit: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_6bit: - return new QuantizerTemplate< - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_8bit_uniform: - return new QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_4bit_uniform: - return new QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>(d, trained); - case ScalarQuantizer::QT_fp16: - return new QuantizerFP16(d, trained); - case ScalarQuantizer::QT_bf16: - return new QuantizerBF16(d, trained); - case ScalarQuantizer::QT_8bit_direct: - return new Quantizer8bitDirect(d, trained); - case ScalarQuantizer::QT_8bit_direct_signed: - return new Quantizer8bitDirectSigned(d, trained); - } - FAISS_THROW_MSG("unknown qtype"); -} - /******************************************************************* - * select_distance_computer: runtime selection of template - * specialization + * InvertedListScanner: scans series of codes and keeps the best ones *******************************************************************/ -template -SQDistanceComputer* select_distance_computer( - QuantizerType qtype, - size_t d, - const std::vector& trained) { - constexpr int SIMDWIDTH = Sim::simdwidth; - switch (qtype) { - case ScalarQuantizer::QT_8bit_uniform: - return new DCTemplate< - QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_4bit_uniform: - return new DCTemplate< - QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_8bit: - return new DCTemplate< - QuantizerTemplate< - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_6bit: - return new DCTemplate< - QuantizerTemplate< - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_4bit: - return new DCTemplate< - QuantizerTemplate< - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM, - SIMDWIDTH>, - Sim, - SIMDWIDTH>(d, trained); - - case ScalarQuantizer::QT_fp16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - - case ScalarQuantizer::QT_bf16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - - case ScalarQuantizer::QT_8bit_direct: -#if defined(__AVX512F__) - if (d % 32 == 0) { - return new DistanceComputerByte(d, trained); - } else -#elif defined(__AVX2__) - if (d % 16 == 0) { - return new 
DistanceComputerByte(d, trained); - } else -#endif - { - return new DCTemplate< - Quantizer8bitDirect, - Sim, - SIMDWIDTH>(d, trained); - } - case ScalarQuantizer::QT_8bit_direct_signed: - return new DCTemplate< - Quantizer8bitDirectSigned, - Sim, - SIMDWIDTH>(d, trained); - } - FAISS_THROW_MSG("unknown qtype"); - return nullptr; -} - -} // namespace scalar_quantizer +#include +namespace faiss { using namespace scalar_quantizer; /******************************************************************* @@ -320,18 +152,19 @@ void ScalarQuantizer::train(size_t n, const float* x) { } ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const { -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - return select_quantizer_1<16>(qtype, d, trained); - } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - return select_quantizer_1<8>(qtype, d, trained); - } else + // here we can't just dispatch because the SIMD code works only on certain + // vector sizes +#ifdef COMPILE_SIMD_AVX512 + if (d % 16 == 0 && SIMDConfig::level == SIMDLevel::AVX512) { + return select_quantizer_1(qtype, d, trained); + } #endif - { - return select_quantizer_1<1>(qtype, d, trained); +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level == SIMDLevel::AVX2) { + return select_quantizer_1(qtype, d, trained); } +#endif + return select_quantizer_1(qtype, d, trained); } void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n) @@ -356,33 +189,20 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { SQDistanceComputer* ScalarQuantizer::get_distance_computer( MetricType metric) const { - FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - if (metric == METRIC_L2) { - return select_distance_computer>( - qtype, d, trained); - } else { - return select_distance_computer>( - qtype, d, trained); - } - } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); - } else { - return select_distance_computer>(qtype, d, trained); - } - } else +#ifdef COMPILE_SIMD_AVX512 + if (d % 16 == 0 && SIMDConfig::level == SIMDLevel::AVX512) { + return select_distance_computer_1( + metric, qtype, d, trained); + } #endif - { - if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); - } else { - return select_distance_computer>(qtype, d, trained); - } +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level == SIMDLevel::AVX2) { + return select_distance_computer_1( + metric, qtype, d, trained); } +#endif + return select_distance_computer_1( + metric, qtype, d, trained); } /******************************************************************* @@ -392,366 +212,26 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer( * IndexScalarQuantizer as well. 
********************************************************************/ -namespace { - -template -struct IVFSQScannerIP : InvertedListScanner { - DCClass dc; - bool by_residual; - - float accu0; /// added to all distances - - IVFSQScannerIP( - int d, - const std::vector& trained, - size_t code_size, - bool store_pairs, - const IDSelector* sel, - bool by_residual) - : dc(d, trained), by_residual(by_residual), accu0(0) { - this->store_pairs = store_pairs; - this->sel = sel; - this->code_size = code_size; - this->keep_max = true; - } - - void set_query(const float* query) override { - dc.set_query(query); - } - - void set_list(idx_t list_no, float coarse_dis) override { - this->list_no = list_no; - accu0 = by_residual ? coarse_dis : 0; - } - - float distance_to_code(const uint8_t* code) const final { - return accu0 + dc.query_to_code(code); - } - - size_t scan_codes( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float* simi, - idx_t* idxi, - size_t k) const override { - size_t nup = 0; - - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float accu = accu0 + dc.query_to_code(codes); - - if (accu > simi[0]) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - minheap_replace_top(k, simi, idxi, accu, id); - nup++; - } - } - return nup; - } - - void scan_codes_range( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float radius, - RangeQueryResult& res) const override { - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float accu = accu0 + dc.query_to_code(codes); - if (accu > radius) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - res.add(accu, id); - } - } - } -}; - -/* use_sel = 0: don't check selector - * = 1: check on ids[j] - * = 2: check in j directly (normally ids is nullptr and store_pairs) - */ -template -struct IVFSQScannerL2 : InvertedListScanner { - DCClass dc; - - bool by_residual; - const Index* quantizer; - const float* x; /// current query - - std::vector tmp; - - IVFSQScannerL2( - int d, - const std::vector& trained, - size_t code_size, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool by_residual) - : dc(d, trained), - by_residual(by_residual), - quantizer(quantizer), - x(nullptr), - tmp(d) { - this->store_pairs = store_pairs; - this->sel = sel; - this->code_size = code_size; - } - - void set_query(const float* query) override { - x = query; - if (!quantizer) { - dc.set_query(query); - } - } - - void set_list(idx_t list_no, float /*coarse_dis*/) override { - this->list_no = list_no; - if (by_residual) { - // shift of x_in wrt centroid - quantizer->compute_residual(x, tmp.data(), list_no); - dc.set_query(tmp.data()); - } else { - dc.set_query(x); - } - } - - float distance_to_code(const uint8_t* code) const final { - return dc.query_to_code(code); - } - - size_t scan_codes( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float* simi, - idx_t* idxi, - size_t k) const override { - size_t nup = 0; - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float dis = dc.query_to_code(codes); - - if (dis < simi[0]) { - int64_t id = store_pairs ? 
(list_no << 32 | j) : ids[j]; - maxheap_replace_top(k, simi, idxi, dis, id); - nup++; - } - } - return nup; - } - - void scan_codes_range( - size_t list_size, - const uint8_t* codes, - const idx_t* ids, - float radius, - RangeQueryResult& res) const override { - for (size_t j = 0; j < list_size; j++, codes += code_size) { - if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { - continue; - } - - float dis = dc.query_to_code(codes); - if (dis < radius) { - int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - res.add(dis, id); - } - } - } -}; - -template -InvertedListScanner* sel3_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - if (DCClass::Sim::metric_type == METRIC_L2) { - return new IVFSQScannerL2( - sq->d, - sq->trained, - sq->code_size, - quantizer, - store_pairs, - sel, - r); - } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { - return new IVFSQScannerIP( - sq->d, sq->trained, sq->code_size, store_pairs, sel, r); - } else { - FAISS_THROW_MSG("unsupported metric type"); - } -} - -template -InvertedListScanner* sel2_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - if (sel) { - if (store_pairs) { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } else { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } - } else { - return sel3_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); - } -} - -template -InvertedListScanner* sel12_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate; - using DCClass = DCTemplate; - return sel2_InvertedListScanner( - sq, quantizer, store_pairs, sel, r); -} - -template -InvertedListScanner* sel1_InvertedListScanner( - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool r) { - constexpr int SIMDWIDTH = Similarity::simdwidth; - switch (sq->qtype) { - case ScalarQuantizer::QT_8bit_uniform: - return sel12_InvertedListScanner< - Similarity, - Codec8bit, - QuantizerTemplateScaling::UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_4bit_uniform: - return sel12_InvertedListScanner< - Similarity, - Codec4bit, - QuantizerTemplateScaling::UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_8bit: - return sel12_InvertedListScanner< - Similarity, - Codec8bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_4bit: - return sel12_InvertedListScanner< - Similarity, - Codec4bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_6bit: - return sel12_InvertedListScanner< - Similarity, - Codec6bit, - QuantizerTemplateScaling::NON_UNIFORM>( - sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_fp16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_bf16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - case ScalarQuantizer::QT_8bit_direct: -#if defined(__AVX512F__) - if (sq->d % 32 == 0) { - return sel2_InvertedListScanner< - DistanceComputerByte>( - sq, quantizer, store_pairs, sel, r); - } else 
-#elif defined(__AVX2__) - if (sq->d % 16 == 0) { - return sel2_InvertedListScanner< - DistanceComputerByte>( - sq, quantizer, store_pairs, sel, r); - } else -#endif - { - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - } - case ScalarQuantizer::QT_8bit_direct_signed: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); - } - - FAISS_THROW_MSG("unknown qtype"); - return nullptr; -} - -template -InvertedListScanner* sel0_InvertedListScanner( - MetricType mt, - const ScalarQuantizer* sq, - const Index* quantizer, - bool store_pairs, - const IDSelector* sel, - bool by_residual) { - if (mt == METRIC_L2) { - return sel1_InvertedListScanner>( - sq, quantizer, store_pairs, sel, by_residual); - } else if (mt == METRIC_INNER_PRODUCT) { - return sel1_InvertedListScanner>( - sq, quantizer, store_pairs, sel, by_residual); - } else { - FAISS_THROW_MSG("unsupported metric type"); - } -} - -} // anonymous namespace - InvertedListScanner* ScalarQuantizer::select_InvertedListScanner( MetricType mt, const Index* quantizer, bool store_pairs, const IDSelector* sel, bool by_residual) const { -#if defined(USE_AVX512_F16C) - if (d % 16 == 0) { - return sel0_InvertedListScanner<16>( +#ifdef COMPILE_SIMD_AVX512 + if (d % 16 == 0 && SIMDConfig::level == SIMDLevel::AVX512) { + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else -#elif defined(USE_F16C) || defined(USE_NEON) - if (d % 8 == 0) { - return sel0_InvertedListScanner<8>( +#endif +#ifdef COMPILE_SIMD_AVX2 + if (d % 8 == 0 && SIMDConfig::level == SIMDLevel::AVX2) { + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else #endif - { - return sel0_InvertedListScanner<1>( + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); - } } } // namespace faiss diff --git a/faiss/impl/code_distance/code_distance-avx512.cpp b/faiss/impl/code_distance/code_distance-avx512.cpp index aa16b1c4b8..bff3a72968 100644 --- a/faiss/impl/code_distance/code_distance-avx512.cpp +++ b/faiss/impl/code_distance/code_distance-avx512.cpp @@ -192,7 +192,7 @@ struct PQCodeDistance { }; // explicit template instanciations -// template struct PQCodeDistance; +// template struct PQCodeDistance; // these two will automatically use the generic implementation template struct PQCodeDistance; diff --git a/faiss/impl/scalar_quantizer/codecs.h b/faiss/impl/scalar_quantizer/codecs.h index 31c75bc632..b5c20d464b 100644 --- a/faiss/impl/scalar_quantizer/codecs.h +++ b/faiss/impl/scalar_quantizer/codecs.h @@ -8,6 +8,7 @@ #pragma once #include +#include namespace faiss { @@ -19,7 +20,17 @@ namespace scalar_quantizer { * index). 
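 *
 * Each codec is now a primary template over SIMDLevel: the SIMDLevel::NONE
 * specializations below keep the scalar encode_component / decode_component,
 * and the per-level translation units (impl-avx2.cpp, impl-avx512.cpp,
 * impl-neon.cpp) specialize the same templates to layer wide decoders on top
 * of them, roughly as in this sketch (assuming SIMDLevel and simd8float32 as
 * used elsewhere in this patch):
 *
 *   template <>
 *   struct Codec8bit<SIMDLevel::AVX2> : Codec8bit<SIMDLevel::NONE> {
 *       static simd8float32 decode_8_components(const uint8_t* code, int i);
 *   };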
*/ -struct Codec8bit { +template +struct Codec8bit {}; + +template +struct Codec4bit {}; + +template +struct Codec6bit {}; + +template <> +struct Codec8bit { static FAISS_ALWAYS_INLINE void encode_component( float x, uint8_t* code, @@ -32,45 +43,9 @@ struct Codec8bit { int i) { return (code[i] + 0.5f) / 255.0f; } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE simd16float32 - decode_16_components(const uint8_t* code, int i) { - const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); - const __m512i i32 = _mm512_cvtepu8_epi32(c16); - const __m512 f16 = _mm512_cvtepi32_ps(i32); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); - return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE simd8float32 - decode_8_components(const uint8_t* code, int i) { - const uint64_t c8 = *(uint64_t*)(code + i); - - const __m128i i8 = _mm_set1_epi64x(c8); - const __m256i i32 = _mm256_cvtepu8_epi32(i8); - const __m256 f8 = _mm256_cvtepi32_ps(i32); - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); - return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return simd8float32(float32x4x2_t{res1, res2}); - } -#endif }; - -struct Codec4bit { +template <> +struct Codec4bit { static FAISS_ALWAYS_INLINE void encode_component( float x, uint8_t* code, @@ -83,64 +58,10 @@ struct Codec4bit { int i) { return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; } - -#if defined(__AVX512F__) - static FAISS_ALWAYS_INLINE simd16float32 - decode_16_components(const uint8_t* code, int i) { - uint64_t c8 = *(uint64_t*)(code + (i >> 1)); - uint64_t mask = 0x0f0f0f0f0f0f0f0f; - uint64_t c8ev = c8 & mask; - uint64_t c8od = (c8 >> 4) & mask; - - __m128i c16 = - _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); - __m256i c8lo = _mm256_cvtepu8_epi32(c16); - __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); - __m512i i16 = _mm512_castsi256_si512(c8lo); - i16 = _mm512_inserti32x8(i16, c8hi, 1); - __m512 f16 = _mm512_cvtepi32_ps(i16); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); - return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); - } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE simd8float32 - decode_8_components(const uint8_t* code, int i) { - uint32_t c4 = *(uint32_t*)(code + (i >> 1)); - uint32_t mask = 0x0f0f0f0f; - uint32_t c4ev = c4 & mask; - uint32_t c4od = (c4 >> 4) & mask; - - // the 8 lower bytes of c8 contain the values - __m128i c8 = - _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); - __m128i c4lo = _mm_cvtepu8_epi32(c8); - __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); - __m256i i8 = _mm256_castsi128_si256(c4lo); - i8 = _mm256_insertf128_si256(i8, c4hi, 1); - __m256 f8 = _mm256_cvtepi32_ps(i8); - __m256 half = _mm256_set1_ps(0.5f); - f8 = _mm256_add_ps(f8, half); - __m256 one_255 = _mm256_set1_ps(1.f / 15.f); - return simd8float32(_mm256_mul_ps(f8, one_255)); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE simd8float32 - decode_8_components(const uint8_t* 
code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return simd8float32({res1, res2}); - } -#endif }; -struct Codec6bit { +template <> +struct Codec6bit { static FAISS_ALWAYS_INLINE void encode_component( float x, uint8_t* code, @@ -188,117 +109,6 @@ struct Codec6bit { } return (bits + 0.5f) / 63.0f; } - -#if defined(__AVX512F__) - - static FAISS_ALWAYS_INLINE simd16float32 - decode_16_components(const uint8_t* code, int i) { - // pure AVX512 implementation (not necessarily the fastest). - // see: - // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h - - // clang-format off - - // 16 components, 16x6 bit=12 bytes - const __m128i bit_6v = - _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); - const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); - - // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F - // 00 01 02 03 - const __m256i shuffle_mask = _mm256_setr_epi16( - 0xFF00, 0x0100, 0x0201, 0xFF02, - 0xFF03, 0x0403, 0x0504, 0xFF05, - 0xFF06, 0x0706, 0x0807, 0xFF08, - 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); - const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); - - // 0: xxxxxxxx xx543210 - // 1: xxxx5432 10xxxxxx - // 2: xxxxxx54 3210xxxx - // 3: xxxxxxxx 543210xx - const __m256i shift_right_v = _mm256_setr_epi16( - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U, - 0x0U, 0x6U, 0x4U, 0x2U); - __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); - - // remove unneeded bits - shuffled_shifted = - _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); - - // scale - const __m512 f8 = - _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); - const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); - const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); - return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255)); - - // clang-format on - } - -#elif defined(__AVX2__) - - /* Load 6 bytes that represent 8 6-bit values, return them as a - * 8*32 bit vector register */ - static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { - const __m128i perm = _mm_set_epi8( - -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); - const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); - - // load 6 bytes - __m128i c1 = - _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); - - // put in 8 * 32 bits - __m128i c2 = _mm_shuffle_epi8(c1, perm); - __m256i c3 = _mm256_cvtepi16_epi32(c2); - - // shift and mask out useless bits - __m256i c4 = _mm256_srlv_epi32(c3, shifts); - __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); - return c5; - } - - static FAISS_ALWAYS_INLINE simd8float32 - decode_8_components(const uint8_t* code, int i) { - // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here - // // for the reference, maybe, it becomes used oned day. 
- // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); - // const uint32_t* data32 = (const uint32_t*)data16; - // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); - // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); - // const __m128i i8 = _mm_set1_epi64x(vext); - // const __m256i i32 = _mm256_cvtepi8_epi32(i8); - // const __m256 f8 = _mm256_cvtepi32_ps(i32); - // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - // return _mm256_fmadd_ps(f8, one_255, half_one_255); - - __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); - __m256 f8 = _mm256_cvtepi32_ps(i8); - // this could also be done with bit manipulations but it is - // not obviously faster - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); - } - -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE simd8float32 - decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return simd8float32(float32x4x2_t({res1, res2})); - } -#endif }; } // namespace scalar_quantizer diff --git a/faiss/impl/scalar_quantizer/distance_computers.h b/faiss/impl/scalar_quantizer/distance_computers.h index 96de493204..7698c265da 100644 --- a/faiss/impl/scalar_quantizer/distance_computers.h +++ b/faiss/impl/scalar_quantizer/distance_computers.h @@ -8,113 +8,98 @@ #pragma once #include +#include +#include +#include namespace faiss { namespace scalar_quantizer { -using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; +/******************************************************************* + * Similarities: accumulates the element-wise similarities + *******************************************************************/ -template -struct DCTemplate : SQDistanceComputer {}; +template +struct SimilarityL2 {}; -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; +template +struct SimilarityIP {}; - Quantizer quant; +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::NONE; + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + const float *y, *yi; - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float xi = quant.reconstruct_component(code, i); - sim.add_component(xi); - } - return sim.result(); - } + explicit SimilarityL2(const float* y) : y(y) {} - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin(); - for (size_t i = 0; i < quant.d; i++) { - float x1 = quant.reconstruct_component(code1, i); - float x2 = quant.reconstruct_component(code2, i); - sim.add_component_2(x1, x2); - } - return sim.result(); + /******* scalar accumulator *******/ + + float accu; + + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; } - void set_query(const float* x) final { - q = x; + FAISS_ALWAYS_INLINE void add_component(float x) { + float tmp = *yi++ - x; + accu += tmp * tmp; } - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); + 
FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + float tmp = x1 - x2; + accu += tmp * tmp; } - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); + FAISS_ALWAYS_INLINE float result() { + return accu; } }; -#if defined(USE_AVX512_F16C) +template <> +struct SimilarityIP { + static constexpr int simdwidth = 1; + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::NONE; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + const float *y, *yi; -template -struct DCTemplate - : SQDistanceComputer { // Update to handle 16 lanes - using Sim = Similarity; - - Quantizer quant; + float accu; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + explicit SimilarityIP(const float* y) : y(y) {} - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - simd16float32 xi = quant.reconstruct_16_components(code, i); - sim.add_16_components(xi); - } - return sim.result_16(); + FAISS_ALWAYS_INLINE void begin() { + accu = 0; + yi = y; } - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_16(); - for (size_t i = 0; i < quant.d; i += 16) { - simd16float32 x1 = quant.reconstruct_16_components(code1, i); - simd16float32 x2 = quant.reconstruct_16_components(code2, i); - sim.add_16_components_2(x1, x2); - } - return sim.result_16(); + FAISS_ALWAYS_INLINE void add_component(float x) { + accu += *yi++ * x; } - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); + FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { + accu += x1 * x2; } - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); + FAISS_ALWAYS_INLINE float result() { + return accu; } }; -#elif defined(USE_F16C) || defined(USE_NEON) +/******************************************************************* + * Distance computers: compute distances between a query and a code + *******************************************************************/ + +using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; + +template +struct DCTemplate : SQDistanceComputer {}; template -struct DCTemplate : SQDistanceComputer { +struct DCTemplate : SQDistanceComputer { using Sim = Similarity; Quantizer quant; @@ -124,24 +109,24 @@ struct DCTemplate : SQDistanceComputer { float compute_distance(const float* x, const uint8_t* code) const { Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - simd8float32 xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float xi = quant.reconstruct_component(code, i); + sim.add_component(xi); } - return sim.result_8(); + return sim.result(); } float compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - simd8float32 x1 = quant.reconstruct_8_components(code1, i); - simd8float32 x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float x1 = quant.reconstruct_component(code1, i); + float x2 = quant.reconstruct_component(code2, i); + sim.add_component_2(x1, x2); } - return sim.result_8(); + return sim.result(); } 
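    // This SIMDLevel::NONE variant is the portable fallback: it walks the
    // vector one component at a time through Similarity::add_component().
    // The AVX2/AVX512/NEON translation units specialize DCTemplate for their
    // own SIMDLevel and consume 8 or 16 components per step; a sketch of the
    // wide inner loop (as written in impl-avx2.cpp, assuming an 8-wide
    // Similarity and Quantizer):
    //
    //   sim.begin_8();
    //   for (size_t i = 0; i < quant.d; i += 8) {
    //       sim.add_8_components(quant.reconstruct_8_components(code, i));
    //   }
    //   return sim.result_8();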
void set_query(const float* x) final { @@ -158,17 +143,15 @@ struct DCTemplate : SQDistanceComputer { } }; -#endif - /******************************************************************* * DistanceComputerByte: computes distances in the integer domain *******************************************************************/ -template +template struct DistanceComputerByte : SQDistanceComputer {}; template -struct DistanceComputerByte : SQDistanceComputer { +struct DistanceComputerByte : SQDistanceComputer { using Sim = Similarity; int d; @@ -211,171 +194,84 @@ struct DistanceComputerByte : SQDistanceComputer { } }; -#if defined(__AVX512F__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - __m512i accu = _mm512_setzero_si512(); - for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time - __m512i c1 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code1 + i))); - __m512i c2 = _mm512_cvtepu8_epi16( - _mm256_loadu_si256((__m256i*)(code2 + i))); - __m512i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm512_madd_epi16(c1, c2); - } else { - __m512i diff = _mm512_sub_epi16(c1, c2); - prod32 = _mm512_madd_epi16(diff, diff); - } - accu = _mm512_add_epi32(accu, prod32); - } - // Horizontally add elements of accu - return _mm512_reduce_add_epi32(accu); - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#elif defined(__AVX2__) - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - // __m256i accu = _mm256_setzero_ps (); - __m256i accu = _mm256_setzero_si256(); - for (int i = 0; i < d; i += 16) { - // load 16 bytes, convert to 16 uint16_t - __m256i c1 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code1 + i))); - __m256i c2 = _mm256_cvtepu8_epi16( - _mm_loadu_si128((__m128i*)(code2 + i))); - __m256i prod32; - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - prod32 = _mm256_madd_epi16(c1, c2); - } else { - __m256i diff = _mm256_sub_epi16(c1, c2); - prod32 = _mm256_madd_epi16(diff, diff); - } - accu = _mm256_add_epi32(accu, prod32); - } - __m128i sum = _mm256_extractf128_si256(accu, 0); - sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); - sum = _mm_hadd_epi32(sum, sum); - sum = _mm_hadd_epi32(sum, sum); - return _mm_cvtsi128_si32(sum); - } - - void set_query(const float* x) final { - /* - for (int i = 0; i < d; i += 8) { - __m256 xi = _mm256_loadu_ps (x + i); - __m256i ci = _mm256_cvtps_epi32(xi); - */ - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return 
compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct DistanceComputerByte : SQDistanceComputer { - using Sim = Similarity; - - int d; - std::vector tmp; - - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} - - int compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - int accu = 0; - for (int i = 0; i < d; i++) { - if (Sim::metric_type == METRIC_INNER_PRODUCT) { - accu += int(code1[i]) * code2[i]; - } else { - int diff = int(code1[i]) - code2[i]; - accu += diff * diff; - } - } - return accu; - } - - void set_query(const float* x) final { - for (int i = 0; i < d; i++) { - tmp[i] = int(x[i]); - } - } - - int compute_distance(const float* x, const uint8_t* code) { - set_query(x); - return compute_code_distance(tmp.data(), code); - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_code_distance(tmp.data(), code); - } -}; +/******************************************************************* + * select_distance_computer: runtime selection of template + * specialization + *******************************************************************/ -#endif +template +SQDistanceComputer* select_distance_computer( + QuantizerType qtype, + size_t d, + const std::vector& trained) { + constexpr SIMDLevel SL = Sim::SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + switch (qtype) { + case ScalarQuantizer::QT_8bit_uniform: + return new DCTemplate, U, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_4bit_uniform: + return new DCTemplate, U, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_8bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_6bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_4bit: + return new DCTemplate, NU, SL>, Sim, SL>( + d, trained); + + case ScalarQuantizer::QT_fp16: + return new DCTemplate, Sim, SL>(d, trained); + + case ScalarQuantizer::QT_bf16: + return new DCTemplate, Sim, SL>(d, trained); + + case ScalarQuantizer::QT_8bit_direct: + return new DCTemplate, Sim, SL>(d, trained); + case ScalarQuantizer::QT_8bit_direct_signed: + return new DCTemplate, Sim, SL>( + d, trained); + } + FAISS_THROW_MSG("unknown qtype"); + return nullptr; +} + +template +SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained) { + if (metric_type == METRIC_L2) { + return select_distance_computer>(qtype, d, trained); + } else if (metric_type == METRIC_INNER_PRODUCT) { + return select_distance_computer>(qtype, d, trained); + } else { + FAISS_THROW_MSG("unsuppored metric type"); + } +} + +// prevent implicit instantiation of the template +extern template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +extern template SQDistanceComputer* select_distance_computer_1< + SIMDLevel::AVX512>( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); } // namespace scalar_quantizer } // namespace faiss diff --git a/faiss/impl/scalar_quantizer/impl-avx2.cpp 
b/faiss/impl/scalar_quantizer/impl-avx2.cpp new file mode 100644 index 0000000000..42f2e7a6e3 --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-avx2.cpp @@ -0,0 +1,437 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#ifdef COMPILE_SIMD_AVX2 + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#ifdef __F16C__ +#define USE_F16C +#else +#warning \ + "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well" +#endif + +namespace faiss { + +namespace scalar_quantizer { + +/****************************************** Specialization of codecs */ + +template <> +struct Codec8bit : Codec8bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + const uint64_t c8 = *(uint64_t*)(code + i); + + const __m128i i8 = _mm_set1_epi64x(c8); + const __m256i i32 = _mm256_cvtepu8_epi32(i8); + const __m256 f8 = _mm256_cvtepi32_ps(i32); + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } +}; + +template <> +struct Codec4bit : Codec4bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + uint32_t c4 = *(uint32_t*)(code + (i >> 1)); + uint32_t mask = 0x0f0f0f0f; + uint32_t c4ev = c4 & mask; + uint32_t c4od = (c4 >> 4) & mask; + + // the 8 lower bytes of c8 contain the values + __m128i c8 = + _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od)); + __m128i c4lo = _mm_cvtepu8_epi32(c8); + __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4)); + __m256i i8 = _mm256_castsi128_si256(c4lo); + i8 = _mm256_insertf128_si256(i8, c4hi, 1); + __m256 f8 = _mm256_cvtepi32_ps(i8); + __m256 half = _mm256_set1_ps(0.5f); + f8 = _mm256_add_ps(f8, half); + __m256 one_255 = _mm256_set1_ps(1.f / 15.f); + return simd8float32(_mm256_mul_ps(f8, one_255)); + } +}; + +template <> +struct Codec6bit : Codec6bit { + /* Load 6 bytes that represent 8 6-bit values, return them as a + * 8*32 bit vector register */ + static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { + const __m128i perm = _mm_set_epi8( + -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); + const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); + + // load 6 bytes + __m128i c1 = + _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]); + + // put in 8 * 32 bits + __m128i c2 = _mm_shuffle_epi8(c1, perm); + __m256i c3 = _mm256_cvtepi16_epi32(c2); + + // shift and mask out useless bits + __m256i c4 = _mm256_srlv_epi32(c3, shifts); + __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); + return c5; + } + + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here + // // for the reference, maybe, it becomes used oned day. 
+ // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3); + // const uint32_t* data32 = (const uint32_t*)data16; + // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32); + // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL); + // const __m128i i8 = _mm_set1_epi64x(vext); + // const __m256i i32 = _mm256_cvtepi8_epi32(i8); + // const __m256 f8 = _mm256_cvtepi32_ps(i32); + // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + // return _mm256_fmadd_ps(f8, one_255, half_one_255); + + __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); + __m256 f8 = _mm256_cvtepi32_ps(i8); + // this could also be done with bit manipulations but it is + // not obviously faster + const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); + const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); + } +}; + +/****************************************** Specialization of quantizers */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin))); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m256 xi = Codec::decode_8_components(code, i).f; + return simd8float32(_mm256_fmadd_ps( + xi, + _mm256_loadu_ps(this->vdiff + i), + _mm256_loadu_ps(this->vmin + i))); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + return simd8float32(_mm256_cvtph_ps(codei)); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); + code_256i = _mm256_slli_epi32(code_256i, 16); + return simd8float32(_mm256_castsi256_ps(code_256i)); + } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 + return simd8float32(_mm256_cvtepi32_ps(y8)); // 8 * float32 + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * 
int32 + __m256i c8 = _mm256_set1_epi32(128); + __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes + return simd8float32(_mm256_cvtepi32_ps(z8)); // 8 * float32 + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX2; + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + __m256 tmp = _mm256_sub_ps(yiv, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y_2) { + __m256 tmp = _mm256_sub_ps(y_2.f, x.f); + accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; + +template <> +struct SimilarityIP { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX2; + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + __m256 yiv = _mm256_loadu_ps(yi); + yi += 8; + accu8.f = _mm256_fmadd_ps(yiv, x.f, accu8.f); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x1, + simd8float32 x2) { + accu8.f = _mm256_fmadd_ps(x1.f, x2.f, accu8.f); + } + + FAISS_ALWAYS_INLINE float result_8() { + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(accu8.f), + _mm256_extractf128_ps(accu8.f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } +}; + +/****************************************** Specialization of distance computers + */ + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + 
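        // code-to-code distance: both stored codes (entries i and j of
        // `codes`) are decoded 8 components at a time and accumulated with
        // the same 8-wide similarity used for query_to_code()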
return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + // __m256i accu = _mm256_setzero_ps (); + __m256i accu = _mm256_setzero_si256(); + for (int i = 0; i < d; i += 16) { + // load 16 bytes, convert to 16 uint16_t + __m256i c1 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code1 + i))); + __m256i c2 = _mm256_cvtepu8_epi16( + _mm_loadu_si128((__m128i*)(code2 + i))); + __m256i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm256_madd_epi16(c1, c2); + } else { + __m256i diff = _mm256_sub_epi16(c1, c2); + prod32 = _mm256_madd_epi16(diff, diff); + } + accu = _mm256_add_epi32(accu, prod32); + } + __m128i sum = _mm256_extractf128_si256(accu, 0); + sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1)); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); + } + + void set_query(const float* x) final { + /* + for (int i = 0; i < d; i += 8) { + __m256 xi = _mm256_loadu_ps (x + i); + __m256i ci = _mm256_cvtps_epi32(xi); + */ + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +// explicit instantiation + +template ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX2 diff --git a/faiss/impl/scalar_quantizer/impl-avx512.cpp b/faiss/impl/scalar_quantizer/impl-avx512.cpp new file mode 100644 index 0000000000..b0fe1e9eaa --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-avx512.cpp @@ -0,0 +1,409 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#ifdef COMPILE_SIMD_AVX512 + +#include +#include +#include +#include +#include + +#include + +#if defined(__AVX512F__) && defined(__F16C__) +#define USE_AVX512_F16C +#else +#warning "Wrong compiler flags for AVX512_F16C" +#endif + +namespace faiss { + +namespace scalar_quantizer { + +/******************************** Codec specializations */ + +template <> +struct Codec8bit : Codec8bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i)); + const __m512i i32 = _mm512_cvtepu8_epi32(c16); + const __m512 f16 = _mm512_cvtepi32_ps(i32); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +}; + +template <> +struct Codec4bit : Codec4bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + uint64_t c8 = *(uint64_t*)(code + (i >> 1)); + uint64_t mask = 0x0f0f0f0f0f0f0f0f; + uint64_t c8ev = c8 & mask; + uint64_t c8od = (c8 >> 4) & mask; + + __m128i c16 = + _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od)); + __m256i c8lo = _mm256_cvtepu8_epi32(c16); + __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8)); + __m512i i16 = _mm512_castsi256_si512(c8lo); + i16 = _mm512_inserti32x8(i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps(i16); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 15.f); + return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255)); + } +}; + +template <> +struct Codec6bit : Codec6bit { + static FAISS_ALWAYS_INLINE simd16float32 + decode_16_components(const uint8_t* code, int i) { + // pure AVX512 implementation (not necessarily the fastest). 
+ // see: + // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h + + // clang-format off + + // 16 components, 16x6 bit=12 bytes + const __m128i bit_6v = + _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3); + const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v); + + // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F + // 00 01 02 03 + const __m256i shuffle_mask = _mm256_setr_epi16( + 0xFF00, 0x0100, 0x0201, 0xFF02, + 0xFF03, 0x0403, 0x0504, 0xFF05, + 0xFF06, 0x0706, 0x0807, 0xFF08, + 0xFF09, 0x0A09, 0x0B0A, 0xFF0B); + const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask); + + // 0: xxxxxxxx xx543210 + // 1: xxxx5432 10xxxxxx + // 2: xxxxxx54 3210xxxx + // 3: xxxxxxxx 543210xx + const __m256i shift_right_v = _mm256_setr_epi16( + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U, + 0x0U, 0x6U, 0x4U, 0x2U); + __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v); + + // remove unneeded bits + shuffled_shifted = + _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F)); + + // scale + const __m512 f8 = + _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted)); + const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f); + const __m512 one_255 = _mm512_set1_ps(1.f / 63.f); + return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255)); + + // clang-format on + } +}; + +/******************************** Quantizer specializations */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i).f; + return simd16float32(_mm512_fmadd_ps( + xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin))); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT( + d, + trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m512 xi = Codec::decode_16_components(code, i).f; + return simd16float32(_mm512_fmadd_ps( + xi, + _mm512_loadu_ps(this->vdiff + i), + _mm512_loadu_ps(this->vmin + i))); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + return simd16float32(_mm512_cvtph_ps(codei)); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); + __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); + code_512i = _mm512_slli_epi32(code_512i, 16); + return simd16float32(_mm512_castsi512_ps(code_512i)); + } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * 
int32 + return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32 + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} + + FAISS_ALWAYS_INLINE simd16float32 + reconstruct_16_components(const uint8_t* code, int i) const { + __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 + __m512i c16 = _mm512_set1_epi32(128); + __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes + return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32 + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX512; + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + __m512 tmp = _mm512_sub_ps(yiv, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE void add_16_components_2( + simd16float32 x, + simd16float32 y_2) { + __m512 tmp = _mm512_sub_ps(y_2.f, x.f); + accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +template <> +struct SimilarityIP { + static constexpr SIMDLevel SIMD_LEVEL = SIMDLevel::AVX512; + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP(const float* y) : y(y) {} + + simd16float32 accu16; + + FAISS_ALWAYS_INLINE void begin_16() { + accu16.clear(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { + __m512 yiv = _mm512_loadu_ps(yi); + yi += 16; + accu16.f = _mm512_fmadd_ps(yiv, x.f, accu16.f); + } + + FAISS_ALWAYS_INLINE void add_16_components_2( + simd16float32 x1, + simd16float32 x2) { + accu16.f = _mm512_fmadd_ps(x1.f, x2.f, accu16.f); + } + + FAISS_ALWAYS_INLINE float result_16() { + // performs better than dividing into _mm256 and adding + return _mm512_reduce_add_ps(accu16.f); + } +}; + +/****************************************** Specialization of distance computers + */ + +template +struct DCTemplate + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + simd16float32 x1 = quant.reconstruct_16_components(code1, i); + simd16float32 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const 
float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte + : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + __m512i accu = _mm512_setzero_si512(); + for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time + __m512i c1 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32(accu, prod32); + } + // Horizontally add elements of accu + return _mm512_reduce_add_epi32(accu); + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +// explicit instantiation + +template ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template SQDistanceComputer* select_distance_computer_1( + MetricType metric_type, + QuantizerType qtype, + size_t d, + const std::vector& trained); + +template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer + +} // namespace faiss + +#endif // COMPILE_SIMD_AVX512 diff --git a/faiss/impl/scalar_quantizer/impl-neon.cpp b/faiss/impl/scalar_quantizer/impl-neon.cpp new file mode 100644 index 0000000000..5ec3d81847 --- /dev/null +++ b/faiss/impl/scalar_quantizer/impl-neon.cpp @@ -0,0 +1,377 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifdef COMPILE_SIMD_NEON + +#include +#include +#include + +#if defined(__aarch64__) +#if defined(__GNUC__) && __GNUC__ < 8 +#warning \ + "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8" +#else +#define USE_NEON +#endif +#endif + +namespace faiss { + +namespace scalar_quantizer { +/******************************** Codec specializations */ + +template <> +struct Codec8bit { + static FAISS_ALWAYS_INLINE decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t{res1, res2}); + } +}; + +template <> +struct Codec4bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32({res1, res2}); + } +}; + +template <> +struct Codec6bit { + static FAISS_ALWAYS_INLINE simd8float32 + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + return simd8float32(float32x4x2_t({res1, res2})); + } +}; +/******************************** Quantizatoin specializations */ + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i); + return simd8float32(float32x4x2_t( + {vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[0], + vdupq_n_f32(this->vdiff)), + vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[1], + vdupq_n_f32(this->vdiff))})); + } +}; + +template +struct QuantizerT + : QuantizerT { + QuantizerT(size_t d, const std::vector& trained) + : QuantizerT(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i).data; + + float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); + float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); + + return simd8float32( + {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), + vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}); + } +}; + +template <> +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}); + } +}; + +template <> +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8float32( + {vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); 
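        // note: bf16 stores the upper 16 bits of an IEEE float32, so widening
        // each uint16 lane to uint32 and shifting it left by 16 restores the
        // original float bit pattern (the low mantissa bits become zero)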
+ } +}; + +template <> +struct Quantizer8bitDirect + : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + // convert uint16 -> uint32 -> fp32 + return simd8float32( + {vcvtq_f32_u32(vmovl_u16(y8_0)), + vcvtq_f32_u32(vmovl_u16(y8_1))}); + } +}; + +template <> +struct Quantizer8bitDirectSigned + : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned<1>(d, trained) {} + + FAISS_ALWAYS_INLINE simd8float32 + reconstruct_8_components(const uint8_t* code, int i) const { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + float32x4_t z8_0 = vcvtq_f32_u32( + vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 + float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); + + // subtract 128 to convert into signed numbers + return simd8float32( + {vsubq_f32(z8_0, vmovq_n_f32(128.0)), + vsubq_f32(z8_1, vmovq_n_f32(128.0))}); + } +}; + +/****************************************** Specialization of similarities */ + +template <> +struct SimilarityL2 { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + explicit SimilarityL2(const float* y) : y(y) {} + simd8float32 accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + simd8float32 x, + simd8float32 y) { + float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + accu8 = simd8float32({accu8_0, accu8_1}); + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4_t sum_0 = vpaddq_f32(accu8.data.val[0], accu8.data.val[0]); + float32x4_t sum_1 = vpaddq_f32(accu8.data.val[1], accu8.data.val[1]); + + float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); + float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); + return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); + } +}; + +template <> +struct SimilarityIP { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + explicit SimilarityIP(const float* y) : y(y) {} + float32x4x2_t accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); + accu8 = {accu8_0, accu8_1}; + 
} + + FAISS_ALWAYS_INLINE void add_8_components_2( + float32x4x2_t x1, + float32x4x2_t x2) { + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); + accu8 = {accu8_0, accu8_1}; + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4x2_t sum = { + vpaddq_f32(accu8.val[0], accu8.val[0]), + vpaddq_f32(accu8.val[1], accu8.val[1])}; + + float32x4x2_t sum2 = { + vpaddq_f32(sum.val[0], sum.val[0]), + vpaddq_f32(sum.val[1], sum.val[1])}; + return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); + } +}; + +/****************************************** Specialization of distance computers + */ + +// this is the same code as the AVX2 version... Possible to mutualize? +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DistanceComputerByte + : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +} // namespace scalar_quantizer + +} // namespace faiss + +#endif // COMPILE_SIMD_NEON diff --git a/faiss/impl/scalar_quantizer/quantizers.h b/faiss/impl/scalar_quantizer/quantizers.h index a4abf058c6..ec0903a3e7 100644 --- a/faiss/impl/scalar_quantizer/quantizers.h +++ b/faiss/impl/scalar_quantizer/quantizers.h @@ -8,28 +8,37 @@ #pragma once #include +#include +#include + +#include + +#include +#include namespace faiss { namespace scalar_quantizer { +using QuantizerType = ScalarQuantizer::QuantizerType; + /******************************************************************* * Quantizer: normalizes scalar vector components, then passes them * through a 
codec *******************************************************************/ -enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; +enum class QScaling { UNIFORM = 0, NON_UNIFORM = 1 }; -template -struct QuantizerTemplate {}; +template +struct QuantizerT {}; template -struct QuantizerTemplate +struct QuantizerT : ScalarQuantizer::SQuantizer { const size_t d; const float vmin, vdiff; - QuantizerTemplate(size_t d, const std::vector& trained) + QuantizerT(size_t d, const std::vector& trained) : d(d), vmin(trained[0]), vdiff(trained[1]) {} void encode_vector(const float* x, uint8_t* code) const final { @@ -62,78 +71,13 @@ struct QuantizerTemplate } }; -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i); - return simd16float32(_mm512_fmadd_ps( - xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin))); - } -}; - -#elif defined(__AVX2__) - template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i).f; - return simd8float32(_mm256_fmadd_ps( - xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin))); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - return simd8float32(float32x4x2_t( - {vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[0], - vdupq_n_f32(this->vdiff)), - vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[1], - vdupq_n_f32(this->vdiff))})); - } -}; - -#endif - -template -struct QuantizerTemplate +struct QuantizerT : ScalarQuantizer::SQuantizer { const size_t d; const float *vmin, *vdiff; - QuantizerTemplate(size_t d, const std::vector& trained) + QuantizerT(size_t d, const std::vector& trained) : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} void encode_vector(const float* x, uint8_t* code) const final { @@ -166,85 +110,19 @@ struct QuantizerTemplate } }; -#if defined(__AVX512F__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m512 xi = Codec::decode_16_components(code, i).f; - return simd16float32(_mm512_fmadd_ps( - xi, - _mm512_loadu_ps(this->vdiff + i), - _mm512_loadu_ps(this->vmin + i))); - } -}; - -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i).f; - return simd8float32(_mm256_fmadd_ps( - xi, - 
_mm256_loadu_ps(this->vdiff + i), - _mm256_loadu_ps(this->vmin + i))); - } -}; - -#endif - -#ifdef USE_NEON - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate< - Codec, - QuantizerTemplateScaling::NON_UNIFORM, - 1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i).data; - - float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); - float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); - - return simd8float32( - {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), - vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}); - } -}; - -#endif +/******************************************************************* + * Quantizers that are not based on codecs + *******************************************************************/ /******************************************************************* * FP16 quantizer *******************************************************************/ -template +template struct QuantizerFP16 {}; template <> -struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { +struct QuantizerFP16 : ScalarQuantizer::SQuantizer { const size_t d; QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} @@ -267,64 +145,15 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { } }; -#if defined(USE_AVX512_F16C) - -template <> -struct QuantizerFP16<16> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); - return simd16float32(_mm512_cvtph_ps(codei)); - } -}; - -#endif - -#if defined(USE_F16C) - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - return simd8float32(_mm256_cvtph_ps(codei)); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return simd8float32( - {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), - vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}); - } -}; -#endif - /******************************************************************* * BF16 quantizer *******************************************************************/ -template +template struct QuantizerBF16 {}; template <> -struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { +struct QuantizerBF16 : ScalarQuantizer::SQuantizer { const size_t d; QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} @@ -347,67 +176,15 @@ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { } }; -#if defined(__AVX512F__) - -template <> -struct QuantizerBF16<16> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m256i code_256i = 
_mm256_loadu_si256((const __m256i*)(code + 2 * i)); - __m512i code_512i = _mm512_cvtepu16_epi32(code_256i); - code_512i = _mm512_slli_epi32(code_512i, 16); - return simd16float32(_mm512_castsi512_ps(code_512i)); - } -}; - -#elif defined(__AVX2__) - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); - code_256i = _mm256_slli_epi32(code_256i, 16); - return simd8float32(_mm256_castsi256_ps(code_256i)); - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return simd8float32( - {vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), - vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); - } -}; -#endif - /******************************************************************* * 8bit_direct quantizer *******************************************************************/ -template +template struct Quantizer8bitDirect {}; template <> -struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { +struct Quantizer8bitDirect : ScalarQuantizer::SQuantizer { const size_t d; Quantizer8bitDirect(size_t d, const std::vector& /* unused */) @@ -431,70 +208,16 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { } }; -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - return simd8float32(_mm256_cvtepi32_ps(y8)); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - // convert uint16 -> uint32 -> fp32 - return simd8float32( - {vcvtq_f32_u32(vmovl_u16(y8_0)), - vcvtq_f32_u32(vmovl_u16(y8_1))}); - } -}; - -#endif - /******************************************************************* * 8bit_direct_signed quantizer 
*******************************************************************/ -template +template struct Quantizer8bitDirectSigned {}; template <> -struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { +struct Quantizer8bitDirectSigned + : ScalarQuantizer::SQuantizer { const size_t d; Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) @@ -518,68 +241,52 @@ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { } }; -#if defined(__AVX512F__) - -template <> -struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd16float32 - reconstruct_16_components(const uint8_t* code, int i) const { - __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8 - __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32 - __m512i c16 = _mm512_set1_epi32(128); - __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes - return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32 - } -}; - -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - __m256i c8 = _mm256_set1_epi32(128); - __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes - return simd8float32(_mm256_cvtepi32_ps(z8)); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE simd8float32 - reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - float32x4_t z8_0 = vcvtq_f32_u32( - vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 - float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); - - // subtract 128 to convert into signed numbers - return simd8float32( - {vsubq_f32(z8_0, vmovq_n_f32(128.0)), - vsubq_f32(z8_1, vmovq_n_f32(128.0))}); - } -}; - -#endif +template +ScalarQuantizer::SQuantizer* select_quantizer_1( + QuantizerType qtype, + size_t d, + const std::vector& trained) { + // constexpr SIMDLevel SL = INSTANCIATE_SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + switch (qtype) { + case ScalarQuantizer::QT_8bit: + return new QuantizerT, NU, SL>(d, trained); + + case ScalarQuantizer::QT_6bit: + return new QuantizerT, NU, SL>(d, trained); + case ScalarQuantizer::QT_4bit: + return new QuantizerT, NU, SL>(d, trained); + case ScalarQuantizer::QT_8bit_uniform: + return new QuantizerT, U, SL>(d, trained); + case ScalarQuantizer::QT_4bit_uniform: + return new QuantizerT, U, SL>(d, trained); + case ScalarQuantizer::QT_fp16: + return new QuantizerFP16(d, trained); + case ScalarQuantizer::QT_bf16: + return new QuantizerBF16(d, trained); + case ScalarQuantizer::QT_8bit_direct: + return new Quantizer8bitDirect(d, trained); + case ScalarQuantizer::QT_8bit_direct_signed: + return 
new Quantizer8bitDirectSigned(d, trained);
+        default:
+            FAISS_THROW_MSG("unknown qtype");
+            return nullptr;
+    }
+}
+
+// prevent implicit instantiation
+extern template ScalarQuantizer::SQuantizer* select_quantizer_1<
+        SIMDLevel::AVX2>(
+        QuantizerType qtype,
+        size_t d,
+        const std::vector& trained);
+
+extern template ScalarQuantizer::SQuantizer* select_quantizer_1<
+        SIMDLevel::AVX512>(
+        QuantizerType qtype,
+        size_t d,
+        const std::vector& trained);

 } // namespace scalar_quantizer
diff --git a/faiss/impl/scalar_quantizer/scanners.h b/faiss/impl/scalar_quantizer/scanners.h
new file mode 100644
index 0000000000..df06358a9a
--- /dev/null
+++ b/faiss/impl/scalar_quantizer/scanners.h
@@ -0,0 +1,356 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace faiss {
+
+namespace scalar_quantizer {
+
+/*******************************************************************
+ * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
+ *
+ * It is an InvertedListScanner, but is designed to work with
+ * IndexScalarQuantizer as well.
+ ********************************************************************/
+
+template
+struct IVFSQScannerIP : InvertedListScanner {
+    DCClass dc;
+    bool by_residual;
+
+    float accu0; /// added to all distances
+
+    IVFSQScannerIP(
+            int d,
+            const std::vector& trained,
+            size_t code_size,
+            bool store_pairs,
+            const IDSelector* sel,
+            bool by_residual)
+            : dc(d, trained), by_residual(by_residual), accu0(0) {
+        this->store_pairs = store_pairs;
+        this->sel = sel;
+        this->code_size = code_size;
+        this->keep_max = true;
+    }
+
+    void set_query(const float* query) override {
+        dc.set_query(query);
+    }
+
+    void set_list(idx_t list_no, float coarse_dis) override {
+        this->list_no = list_no;
+        accu0 = by_residual ? coarse_dis : 0;
+    }
+
+    float distance_to_code(const uint8_t* code) const final {
+        return accu0 + dc.query_to_code(code);
+    }
+
+    size_t scan_codes(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float* simi,
+            idx_t* idxi,
+            size_t k) const override {
+        size_t nup = 0;
+
+        for (size_t j = 0; j < list_size; j++, codes += code_size) {
+            if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
+                continue;
+            }
+
+            float accu = accu0 + dc.query_to_code(codes);
+
+            if (accu > simi[0]) {
+                int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
+                minheap_replace_top(k, simi, idxi, accu, id);
+                nup++;
+            }
+        }
+        return nup;
+    }
+
+    void scan_codes_range(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float radius,
+            RangeQueryResult& res) const override {
+        for (size_t j = 0; j < list_size; j++, codes += code_size) {
+            if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
+                continue;
+            }
+
+            float accu = accu0 + dc.query_to_code(codes);
+            if (accu > radius) {
+                int64_t id = store_pairs ?
(list_no << 32 | j) : ids[j]; + res.add(accu, id); + } + } + } +}; + +/* use_sel = 0: don't check selector + * = 1: check on ids[j] + * = 2: check in j directly (normally ids is nullptr and store_pairs) + */ +template +struct IVFSQScannerL2 : InvertedListScanner { + DCClass dc; + + bool by_residual; + const Index* quantizer; + const float* x; /// current query + + std::vector tmp; + + IVFSQScannerL2( + int d, + const std::vector& trained, + size_t code_size, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual) + : dc(d, trained), + by_residual(by_residual), + quantizer(quantizer), + x(nullptr), + tmp(d) { + this->store_pairs = store_pairs; + this->sel = sel; + this->code_size = code_size; + } + + void set_query(const float* query) override { + x = query; + if (!quantizer) { + dc.set_query(query); + } + } + + void set_list(idx_t list_no, float /*coarse_dis*/) override { + this->list_no = list_no; + if (by_residual) { + // shift of x_in wrt centroid + quantizer->compute_residual(x, tmp.data(), list_no); + dc.set_query(tmp.data()); + } else { + dc.set_query(x); + } + } + + float distance_to_code(const uint8_t* code) const final { + return dc.query_to_code(code); + } + + size_t scan_codes( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float* simi, + idx_t* idxi, + size_t k) const override { + size_t nup = 0; + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float dis = dc.query_to_code(codes); + + if (dis < simi[0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + maxheap_replace_top(k, simi, idxi, dis, id); + nup++; + } + } + return nup; + } + + void scan_codes_range( + size_t list_size, + const uint8_t* codes, + const idx_t* ids, + float radius, + RangeQueryResult& res) const override { + for (size_t j = 0; j < list_size; j++, codes += code_size) { + if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) { + continue; + } + + float dis = dc.query_to_code(codes); + if (dis < radius) { + int64_t id = store_pairs ? 
(list_no << 32 | j) : ids[j]; + res.add(dis, id); + } + } + } +}; + +/* Select the right implementation by dispatching to templatized versions */ + +template +InvertedListScanner* sel3_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + if (DCClass::Sim::metric_type == METRIC_L2) { + return new IVFSQScannerL2( + sq->d, + sq->trained, + sq->code_size, + quantizer, + store_pairs, + sel, + r); + } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { + return new IVFSQScannerIP( + sq->d, sq->trained, sq->code_size, store_pairs, sel, r); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +template +InvertedListScanner* sel2_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + if (sel) { + if (store_pairs) { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } else { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } + } else { + return sel3_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); + } +} + +template +InvertedListScanner* sel12_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + constexpr SIMDLevel SL = Similarity::SIMD_LEVEL; + using QuantizerClass = QuantizerT; + using DCClass = DCTemplate; + return sel2_InvertedListScanner( + sq, quantizer, store_pairs, sel, r); +} + +template +InvertedListScanner* sel1_InvertedListScanner( + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool r) { + constexpr SIMDLevel SL = Sim::SIMD_LEVEL; + constexpr QScaling NU = QScaling::NON_UNIFORM; + constexpr QScaling U = QScaling::UNIFORM; + + switch (sq->qtype) { + case ScalarQuantizer::QT_8bit_uniform: + return sel12_InvertedListScanner, U>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_4bit_uniform: + return sel12_InvertedListScanner, U>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_8bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_4bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_6bit: + return sel12_InvertedListScanner, NU>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_fp16: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + + case ScalarQuantizer::QT_bf16: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + case ScalarQuantizer::QT_8bit_direct: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + + case ScalarQuantizer::QT_8bit_direct_signed: + return sel2_InvertedListScanner< + DCTemplate, Sim, SL>>( + sq, quantizer, store_pairs, sel, r); + default: + FAISS_THROW_MSG("unknown qtype"); + return nullptr; + } +} + +template +InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual) { + if (mt == METRIC_L2) { + return sel1_InvertedListScanner>( + sq, quantizer, store_pairs, sel, by_residual); + } else if (mt == METRIC_INNER_PRODUCT) { + return sel1_InvertedListScanner>( + sq, quantizer, store_pairs, sel, by_residual); + } else { + FAISS_THROW_MSG("unsupported metric 
type"); + } +} + +// prevent implicit instantiation of the template when there are +// SIMD optimized versions... +extern template InvertedListScanner* sel0_InvertedListScanner( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +extern template InvertedListScanner* sel0_InvertedListScanner< + SIMDLevel::AVX512>( + MetricType mt, + const ScalarQuantizer* sq, + const Index* quantizer, + bool store_pairs, + const IDSelector* sel, + bool by_residual); + +} // namespace scalar_quantizer +} // namespace faiss diff --git a/faiss/impl/scalar_quantizer/similarities.h b/faiss/impl/scalar_quantizer/similarities.h deleted file mode 100644 index 99e5b1c089..0000000000 --- a/faiss/impl/scalar_quantizer/similarities.h +++ /dev/null @@ -1,345 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -namespace faiss { - -namespace scalar_quantizer { - -template -struct SimilarityL2 {}; - -template <> -struct SimilarityL2<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - - /******* scalar accumulator *******/ - - float accu; - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - float tmp = *yi++ - x; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - float tmp = x1 - x2; - accu += tmp * tmp; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityL2<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - simd16float32 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16.clear(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - __m512 tmp = _mm512_sub_ps(yiv, x.f); - accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); - } - - FAISS_ALWAYS_INLINE void add_16_components_2( - simd16float32 x, - simd16float32 y_2) { - __m512 tmp = _mm512_sub_ps(y_2.f, x.f); - accu16 = simd16float32(_mm512_fmadd_ps(tmp, tmp, accu16.f)); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16.f); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - - explicit SimilarityL2(const float* y) : y(y) {} - simd8float32 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8.clear(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - __m256 tmp = _mm256_sub_ps(yiv, x.f); - accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - simd8float32 x, - simd8float32 y_2) { - __m256 tmp = _mm256_sub_ps(y_2.f, x.f); - accu8 = simd8float32(_mm256_fmadd_ps(tmp, tmp, accu8.f)); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8.f), - 
_mm256_extractf128_ps(accu8.f, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; - -#endif - -#ifdef USE_NEON -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - explicit SimilarityL2(const float* y) : y(y) {} - simd8float32 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = simd8float32({accu8_0, accu8_1}); - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - simd8float32 x, - simd8float32 y) { - float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = simd8float32({accu8_0, accu8_1}); - } - - FAISS_ALWAYS_INLINE float result_8() { - float32x4_t sum_0 = vpaddq_f32(accu8.data.val[0], accu8.data.val[0]); - float32x4_t sum_1 = vpaddq_f32(accu8.data.val[1], accu8.data.val[1]); - - float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); - float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); - return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); - } -}; -#endif - -template -struct SimilarityIP {}; - -template <> -struct SimilarityIP<1> { - static constexpr int simdwidth = 1; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - FAISS_ALWAYS_INLINE void begin() { - accu = 0; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_component(float x) { - accu += *yi++ * x; - } - - FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) { - accu += x1 * x2; - } - - FAISS_ALWAYS_INLINE float result() { - return accu; - } -}; - -#if defined(__AVX512F__) - -template <> -struct SimilarityIP<16> { - static constexpr int simdwidth = 16; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - simd16float32 accu16; - - FAISS_ALWAYS_INLINE void begin_16() { - accu16.clear(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_16_components(__m512 x) { - __m512 yiv = _mm512_loadu_ps(yi); - yi += 16; - accu16.f = _mm512_fmadd_ps(yiv, x, accu16.f); - } - - FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) { - accu16.f = _mm512_fmadd_ps(x1, x2, accu16.f); - } - - FAISS_ALWAYS_INLINE float result_16() { - // performs better than dividing into _mm256 and adding - return _mm512_reduce_add_ps(accu16.f); - } -}; - -#elif defined(__AVX2__) - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - float accu; - - explicit SimilarityIP(const float* y) : y(y) {} - - simd8float32 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8.clear(); - yi = y; - } - - FAISS_ALWAYS_INLINE void 
add_8_components(simd8float32 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - accu8.f = _mm256_fmadd_ps(yiv, x.f, accu8.f); - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - simd8float32 x1, - simd8float32 x2) { - accu8.f = _mm256_fmadd_ps(x1.f, x2.f, accu8.f); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8.f), - _mm256_extractf128_ps(accu8.f, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; -#endif - -#ifdef USE_NEON - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - explicit SimilarityIP(const float* y) : y(y) {} - float32x4x2_t accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x1, - float32x4x2_t x2) { - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); - accu8 = {accu8_0, accu8_1}; - } - - FAISS_ALWAYS_INLINE float result_8() { - float32x4x2_t sum = { - vpaddq_f32(accu8.val[0], accu8.val[0]), - vpaddq_f32(accu8.val[1], accu8.val[1])}; - - float32x4x2_t sum2 = { - vpaddq_f32(sum.val[0], sum.val[0]), - vpaddq_f32(sum.val[1], sum.val[1])}; - return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); - } -}; -#endif - -} // namespace scalar_quantizer -} // namespace faiss diff --git a/faiss/utils/simd_impl/simdlib_avx512.h b/faiss/utils/simd_impl/simdlib_avx512.h index b1195c7e3c..80889dc508 100644 --- a/faiss/utils/simd_impl/simdlib_avx512.h +++ b/faiss/utils/simd_impl/simdlib_avx512.h @@ -293,4 +293,47 @@ struct simd64uint8 : simd512bit { } }; +struct simd16float32 : simd512bit { + simd16float32() {} + + explicit simd16float32(simd512bit x) : simd512bit(x) {} + + explicit simd16float32(__m512 x) : simd512bit(x) {} + + explicit simd16float32(float x) : simd512bit(_mm512_set1_ps(x)) {} + + explicit simd16float32(const float* x) + : simd16float32(_mm512_loadu_ps(x)) {} + + simd16float32 operator*(simd16float32 other) const { + return simd16float32(_mm512_mul_ps(f, other.f)); + } + + simd16float32 operator+(simd16float32 other) const { + return simd16float32(_mm512_add_ps(f, other.f)); + } + + simd16float32 operator-(simd16float32 other) const { + return simd16float32(_mm512_sub_ps(f, other.f)); + } + + simd16float32& operator+=(const simd16float32& other) { + f = _mm512_add_ps(f, other.f); + return *this; + } + + std::string tostring() const { + float tab[16]; + storeu((void*)tab); + char res[1000]; + char* ptr = res; + for (int i = 0; i < 16; i++) { + ptr += sprintf(ptr, "%g,", tab[i]); + } + // strip last , + ptr[-1] = 0; + return std::string(res); + } +}; + } // namespace faiss
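
Note (not part of the diff): a minimal sketch of how the SIMDLevel-templated entry points introduced above are expected to be driven at runtime. The wrapper function below and its name are illustrative assumptions; only SIMDConfig, SIMDLevel, QuantizerType and select_quantizer_1 come from this patch, and whether select_quantizer_1<SIMDLevel::NONE> is instantiated at the call site or in a dedicated translation unit is an assumption.

#include <vector>

#include <faiss/impl/scalar_quantizer/quantizers.h>
#include <faiss/utils/simd_levels.h>

using namespace faiss;
using namespace faiss::scalar_quantizer;

// Hypothetical call site: pick the quantizer instantiation that matches the
// SIMD level detected at startup (or forced via FAISS_SIMD_LEVEL).
ScalarQuantizer::SQuantizer* make_squantizer(
        QuantizerType qtype,
        size_t d,
        const std::vector<float>& trained) {
    switch (SIMDConfig::get_level()) {
        case SIMDLevel::AVX512:
            // links against the explicit AVX512 instantiation declared extern
            return select_quantizer_1<SIMDLevel::AVX512>(qtype, d, trained);
        case SIMDLevel::AVX2:
            return select_quantizer_1<SIMDLevel::AVX2>(qtype, d, trained);
        default:
            // NONE (and, in this sketch, ARM_NEON) fall back to the scalar path
            return select_quantizer_1<SIMDLevel::NONE>(qtype, d, trained);
    }
}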