diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 0eb2ab5e04..90bac9b998 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -361,7 +361,219 @@ void exhaustive_L2sqr_blas( exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res); } -#ifdef __AVX2__ +#if defined(__AVX512F__) +// AVX512 kernel for the BLAS-based exhaustive L2sqr search that keeps only +// the closest y vector for every x vector (top-1). x holds nx rows and y +// holds ny rows of d floats each. y_norms may be null, in which case the +// norms of y are computed here with fvec_norms_L2sqr(). Results are +// reported through res.add_result(), one call per query and per y block. +void exhaustive_L2sqr_blas_cmax_avx512( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + Top1BlockResultHandler>& res, + const float* y_norms) { + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) { + return; + } + + /* block sizes */ + const size_t bs_x = distance_compute_blas_query_bs; + const size_t bs_y = distance_compute_blas_database_bs; + std::unique_ptr ip_block(new float[bs_x * bs_y]); + std::unique_ptr x_norms(new float[nx]); + std::unique_ptr del2; + + fvec_norms_L2sqr(x_norms.get(), x, d, nx); + + if (!y_norms) { + float* y_norms2 = new float[ny]; + del2.reset(y_norms2); + fvec_norms_L2sqr(y_norms2, y, d, ny); + y_norms = y_norms2; + } + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if (i1 > nx) { + i1 = nx; + } + + res.begin_multiple(i0, i1); + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) { + j1 = ny; + } + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_("Transpose", + "Not transpose", + &nyi, + &nxi, + &di, + &one, + y + j0 * d, + &di, + x + i0 * d, + &di, + &zero, + ip_block.get(), + &nyi); + } +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { + float* ip_line = ip_block.get() + (i - i0) * (j1 - j0); + + _mm_prefetch((const char*)ip_line, _MM_HINT_NTA); + _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA); + + // constant + const __m512 mul_minus2 = _mm512_set1_ps(-2); + + // Track 16 min distances + 16 min indices.
+ // All the distances tracked do not take x_norms[i] + // into account in order to get rid of extra + // _mm512_add_ps(x_norms[i], ...) instructions + // in distance computations. + __m512 min_distances = + _mm512_set1_ps(res.dis_tab[i] - x_norms[i]); + + // these indices are local and are relative to j0. + // so, value 0 means j0. + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_delta = _mm512_set1_epi32(16); + + // current j index + size_t idx_j = 0; + size_t count = j1 - j0; + + // process 32 elements per loop + for (; idx_j < (count / 32) * 32; idx_j += 32, ip_line += 32) { + _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA); + _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA); + + // load values for norms + const __m512 y_norm_0 = + _mm512_loadu_ps(y_norms + idx_j + j0 + 0); + const __m512 y_norm_1 = + _mm512_loadu_ps(y_norms + idx_j + j0 + 16); + + // load values for dot products + const __m512 ip_0 = _mm512_loadu_ps(ip_line + 0); + const __m512 ip_1 = _mm512_loadu_ps(ip_line + 16); + + // compute dis = y_norm[j] - 2 * dot(x[i], y[j]). + // x_norm[i] is left out because it is a constant for a + // given i. It is added back in the scalar reduction below. + __m512 distances_0 = + _mm512_fmadd_ps(ip_0, mul_minus2, y_norm_0); + __m512 distances_1 = + _mm512_fmadd_ps(ip_1, mul_minus2, y_norm_1); + + // compare the new distances to the min distances + // for each of the first group of 16 AVX512 components. + const __mmask16 comparison_0 = _mm512_cmp_ps_mask( + min_distances, distances_0, _CMP_LE_OS); + + // update min distances and indices with closest vectors if + // needed.
+ min_distances = _mm512_mask_blend_ps( + comparison_0, distances_0, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison_0, current_indices, min_indices); + current_indices = + _mm512_add_epi32(current_indices, indices_delta); + + // compare the new distances to the min distances + // for each of the second group of 16 AVX512 components. + const __mmask16 comparison_1 = _mm512_cmp_ps_mask( + min_distances, distances_1, _CMP_LE_OS); + + // update min distances and indices with closest vectors if + // needed. + min_distances = _mm512_mask_blend_ps( + comparison_1, distances_1, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison_1, current_indices, min_indices); + current_indices = + _mm512_add_epi32(current_indices, indices_delta); + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances); + _mm512_storeu_si512( + (__m512i*)(min_indices_scalar), min_indices); + + float current_min_distance = res.dis_tab[i]; + // 64-bit on purpose: index_candidate below is int64_t and + // includes the size_t offset j0, which a uint32_t would truncate. + int64_t current_min_index = res.ids_tab[i]; + + // This unusual comparison is needed to maintain the behavior + // of the original implementation: if two indices are + // represented with equal distance values, then + // the index with the min value is returned. + for (size_t jv = 0; jv < 16; jv++) { + // add missing x_norms[i] + float distance_candidate = + min_distances_scalar[jv] + x_norms[i]; + + // negative values can occur for identical vectors + // due to roundoff errors.
+ if (distance_candidate < 0) { + distance_candidate = 0; + } + + int64_t index_candidate = min_indices_scalar[jv] + j0; + + if (current_min_distance > distance_candidate) { + current_min_distance = distance_candidate; + current_min_index = index_candidate; + } else if ( + current_min_distance == distance_candidate && + current_min_index > index_candidate) { + current_min_index = index_candidate; + } + } + + // process leftovers + for (; idx_j < count; idx_j++, ip_line++) { + float ip = *ip_line; + float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip; + // negative values can occur for identical vectors + // due to roundoff errors. + if (dis < 0) { + dis = 0; + } + + if (current_min_distance > dis) { + current_min_distance = dis; + current_min_index = idx_j + j0; + } + } + + res.add_result(i, current_min_distance, current_min_index); + } + } + // Does nothing for Top1BlockResultHandler, but + // keeping the call for the consistency. + res.end_multiple(); + InterruptCallback::check(); + } +} +#elif defined(__AVX2__) void exhaustive_L2sqr_blas_cmax_avx2( const float* x, const float* y, @@ -761,7 +973,17 @@ void exhaustive_L2sqr_blas>>( size_t ny, Top1BlockResultHandler>& res, const float* y_norms) { -#if defined(__AVX2__) +#if defined(__AVX512F__) + // use a faster fused kernel if available + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { + // the kernel is available and it is complete, we're done. + return; + } + + // run the specialized AVX512 implementation + exhaustive_L2sqr_blas_cmax_avx512(x, y, d, nx, ny, res, y_norms); + +#elif defined(__AVX2__) // use a faster fused kernel if available if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { // the kernel is available and it is complete, we're done.