diff --git a/faiss/IndexLSH.cpp b/faiss/IndexLSH.cpp index a2d29f8173..d84ee948a3 100644 --- a/faiss/IndexLSH.cpp +++ b/faiss/IndexLSH.cpp @@ -86,12 +86,14 @@ void IndexLSH::train(idx_t n, const float* x) { for (idx_t i = 0; i < nbits; i++) { float* xi = transposed_x.get() + i * n; - // std::nth_element - std::sort(xi, xi + n); - if (n % 2 == 1) - thresholds[i] = xi[n / 2]; - else - thresholds[i] = (xi[n / 2 - 1] + xi[n / 2]) / 2; + // Use nth_element (O(n)) instead of sort (O(n log n)) + std::nth_element(xi, xi + n / 2, xi + n); + float median = xi[n / 2]; + if (n % 2 == 0) { + std::nth_element(xi, xi + n / 2 - 1, xi + n); + median = (median + xi[n / 2 - 1]) / 2; + } + thresholds[i] = median; } } is_trained = true;