Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 217 additions & 2 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,212 @@ void exhaustive_L2sqr_blas(
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
}

#ifdef __AVX2__
#if defined(__AVX512F__)
/// Exhaustive squared-L2 top-1 search using BLAS for the inner products and
/// AVX-512 for the per-query min reduction.
///
/// The queries and the database are processed in blocks of
/// distance_compute_blas_query_bs x distance_compute_blas_database_bs.
/// For every block, sgemm_ produces the dot products, and then, for every
/// query i of the block, the candidate with the smallest
/// x_norms[i] + y_norms[j] - 2 * <x_i, y_j> is reported via res.add_result().
///
/// @param x        query vectors, nx rows of d floats
/// @param y        database vectors, ny rows of d floats
/// @param d        vector dimensionality
/// @param nx       number of query vectors
/// @param ny       number of database vectors
/// @param res      result handler; res.dis_tab[i] / res.ids_tab[i] are read
///                 as the running best for query i and updated through
///                 res.add_result()
/// @param y_norms  optional squared norms of the y vectors (ny floats);
///                 computed locally when null
void exhaustive_L2sqr_blas_cmax_avx512(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        Top1BlockResultHandler<CMax<float, int64_t>>& res,
        const float* y_norms) {
    // BLAS does not like empty matrices
    if (nx == 0 || ny == 0) {
        return;
    }

    /* block sizes */
    const size_t bs_x = distance_compute_blas_query_bs;
    const size_t bs_y = distance_compute_blas_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
    std::unique_ptr<float[]> x_norms(new float[nx]);
    // owns the locally computed y norms, if the caller did not provide any
    std::unique_ptr<float[]> del2;

    fvec_norms_L2sqr(x_norms.get(), x, d, nx);

    if (!y_norms) {
        del2.reset(new float[ny]);
        fvec_norms_L2sqr(del2.get(), y, d, ny);
        y_norms = del2.get();
    }

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx) {
            i1 = nx;
        }

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny) {
                j1 = ny;
            }
            /* compute the actual dot products */
            {
                float one = 1, zero = 0;
                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + j0 * d,
                       &di,
                       x + i0 * d,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }
#pragma omp parallel for
            for (int64_t i = i0; i < i1; i++) {
                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

                _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
                _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);

                // constant
                const __m512 mul_minus2 = _mm512_set1_ps(-2);

                // Local index value that marks a SIMD lane which has not
                // accepted any real candidate yet.
                const uint32_t no_candidate = ~uint32_t(0);

                // Track 16 min distances + 16 min indices.
                // All the distances tracked do not take x_norms[i]
                // into account in order to get rid of extra
                // _mm512_add_ps(x_norms[i], ...) instructions
                // in distance computations.
                __m512 min_distances =
                        _mm512_set1_ps(res.dis_tab[i] - x_norms[i]);

                // These indices are local and are relative to j0,
                // so value 0 means j0. Lanes start with the sentinel:
                // the initial threshold above is not a real candidate, and
                // (dis_tab[i] - x_norms[i]) + x_norms[i] may round below
                // dis_tab[i], which would otherwise let an untouched lane
                // overwrite the result with a bogus index.
                __m512i min_indices = _mm512_set1_epi32(-1);

                __m512i current_indices = _mm512_setr_epi32(
                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
                const __m512i indices_delta = _mm512_set1_epi32(16);

                // current j index
                size_t idx_j = 0;
                const size_t count = j1 - j0;
                // number of elements covered by the 32-wide vector loop
                const size_t count_vectorized = (count / 32) * 32;

                // process 32 elements per loop
                for (; idx_j < count_vectorized; idx_j += 32, ip_line += 32) {
                    _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
                    _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);

                    // load values for norms
                    const __m512 y_norm_0 =
                            _mm512_loadu_ps(y_norms + idx_j + j0 + 0);
                    const __m512 y_norm_1 =
                            _mm512_loadu_ps(y_norms + idx_j + j0 + 16);

                    // load values for dot products
                    const __m512 ip_0 = _mm512_loadu_ps(ip_line + 0);
                    const __m512 ip_1 = _mm512_loadu_ps(ip_line + 16);

                    // compute dis = y_norm[j] - 2 * dot(x[i], y[j]).
                    // x_norm[i] was dropped off because it is a constant for
                    // a given i. We'll deal with it later.
                    const __m512 distances_0 =
                            _mm512_fmadd_ps(ip_0, mul_minus2, y_norm_0);
                    const __m512 distances_1 =
                            _mm512_fmadd_ps(ip_1, mul_minus2, y_norm_1);

                    // compare the new distances to the min distances
                    // for each of the first group of 16 AVX512 components.
                    // A set mask bit means "keep the current minimum".
                    const __mmask16 comparison_0 = _mm512_cmp_ps_mask(
                            min_distances, distances_0, _CMP_LE_OS);

                    // update min distances and indices with closest vectors
                    // if needed.
                    min_distances = _mm512_mask_blend_ps(
                            comparison_0, distances_0, min_distances);
                    min_indices = _mm512_mask_blend_epi32(
                            comparison_0, current_indices, min_indices);
                    current_indices =
                            _mm512_add_epi32(current_indices, indices_delta);

                    // compare the new distances to the min distances
                    // for each of the second group of 16 AVX512 components.
                    const __mmask16 comparison_1 = _mm512_cmp_ps_mask(
                            min_distances, distances_1, _CMP_LE_OS);

                    // update min distances and indices with closest vectors
                    // if needed.
                    min_distances = _mm512_mask_blend_ps(
                            comparison_1, distances_1, min_distances);
                    min_indices = _mm512_mask_blend_epi32(
                            comparison_1, current_indices, min_indices);
                    current_indices =
                            _mm512_add_epi32(current_indices, indices_delta);
                }

                // dump values and find the minimum distance / minimum index
                float min_distances_scalar[16];
                uint32_t min_indices_scalar[16];
                _mm512_storeu_ps(min_distances_scalar, min_distances);
                _mm512_storeu_si512(
                        (__m512i*)(min_indices_scalar), min_indices);

                float current_min_distance = res.dis_tab[i];
                // kept as a 64-bit id so that indices >= 2^32 are not
                // truncated
                int64_t current_min_index = res.ids_tab[i];

                // This unusual comparison is needed to maintain the behavior
                // of the original implementation: if two indices are
                // represented with equal distance values, then
                // the index with the min value is returned.
                for (size_t jv = 0; jv < 16; jv++) {
                    // this lane never saw a candidate better than the
                    // running best: nothing to report
                    if (min_indices_scalar[jv] == no_candidate) {
                        continue;
                    }

                    // add missing x_norms[i]
                    float distance_candidate =
                            min_distances_scalar[jv] + x_norms[i];

                    // negative values can occur for identical vectors
                    // due to roundoff errors.
                    if (distance_candidate < 0) {
                        distance_candidate = 0;
                    }

                    const int64_t index_candidate =
                            static_cast<int64_t>(min_indices_scalar[jv]) +
                            static_cast<int64_t>(j0);

                    if (current_min_distance > distance_candidate) {
                        current_min_distance = distance_candidate;
                        current_min_index = index_candidate;
                    } else if (
                            current_min_distance == distance_candidate &&
                            // a negative id compares as larger than any real
                            // index, as it did with the former unsigned
                            // comparison
                            (current_min_index < 0 ||
                             current_min_index > index_candidate)) {
                        current_min_index = index_candidate;
                    }
                }

                // process leftovers
                for (; idx_j < count; idx_j++, ip_line++) {
                    const float ip = *ip_line;
                    float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
                    // negative values can occur for identical vectors
                    // due to roundoff errors.
                    if (dis < 0) {
                        dis = 0;
                    }

                    // leftover indices are larger than all the vectorized
                    // ones, so a strict comparison keeps the smallest index
                    // on ties
                    if (current_min_distance > dis) {
                        current_min_distance = dis;
                        current_min_index = static_cast<int64_t>(idx_j + j0);
                    }
                }

                res.add_result(i, current_min_distance, current_min_index);
            }
        }
        // Does nothing for SingleBestResultHandler, but
        // keeping the call for the consistency.
        res.end_multiple();
        InterruptCallback::check();
    }
}
#elif defined(__AVX2__)
void exhaustive_L2sqr_blas_cmax_avx2(
const float* x,
const float* y,
Expand Down Expand Up @@ -761,7 +966,17 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
size_t ny,
Top1BlockResultHandler<CMax<float, int64_t>>& res,
const float* y_norms) {
#if defined(__AVX2__)
#if defined(__AVX512F__)
// use a faster fused kernel if available
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
// the kernel is available and it is complete, we're done.
return;
}

// run the specialized AVX512 implementation
exhaustive_L2sqr_blas_cmax_avx512(x, y, d, nx, ny, res, y_norms);

#elif defined(__AVX2__)
// use a faster fused kernel if available
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
// the kernel is available and it is complete, we're done.
Expand Down
Loading