Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,9 @@ struct RangeSearchResults {
* The scanning functions call their favorite precompute_*
* function to precompute the tables they need.
*****************************************************/
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
template <typename IDType, MetricType METRIC_TYPE, class PQCodeDistance>
struct IVFPQScannerT : QueryTables {
using PQDecoder = typename PQCodeDistance::PQDecoder;
const uint8_t* list_codes;
const IDType* list_ids;
size_t list_size;
Expand Down Expand Up @@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables {
float distance_1 = 0;
float distance_2 = 0;
float distance_3 = 0;
distance_four_codes<PQDecoder>(
PQCodeDistance::distance_four_codes(
pq.M,
pq.nbits,
sim_table,
Expand All @@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables {

if (counter >= 1) {
float dis = dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
pq.M,
pq.nbits,
sim_table,
Expand All @@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables {
}
if (counter >= 2) {
float dis = dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
pq.M,
pq.nbits,
sim_table,
Expand All @@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables {
}
if (counter >= 3) {
float dis = dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
pq.M,
pq.nbits,
sim_table,
Expand Down Expand Up @@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables {
float distance_1 = dis0;
float distance_2 = dis0;
float distance_3 = dis0;
distance_four_codes<PQDecoder>(
PQCodeDistance::distance_four_codes(
pq.M,
pq.nbits,
sim_table,
Expand Down Expand Up @@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables {
n_hamming_pass++;

float dis = dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
pq.M,
pq.nbits,
sim_table,
Expand All @@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables {
n_hamming_pass++;

float dis = dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
pq.M,
pq.nbits,
sim_table,
Expand Down Expand Up @@ -1197,8 +1198,8 @@ struct IVFPQScannerT : QueryTables {
*
* use_sel: store or ignore the IDSelector
*/
template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
template <MetricType METRIC_TYPE, class C, class PQCodeDistance, bool use_sel>
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>,
InvertedListScanner {
int precompute_mode;
const IDSelector* sel;
Expand All @@ -1208,7 +1209,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
bool store_pairs,
int precompute_mode,
const IDSelector* sel)
: IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
: IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>(ivfpq, nullptr),
precompute_mode(precompute_mode),
sel(sel) {
this->store_pairs = store_pairs;
Expand All @@ -1228,7 +1229,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
float distance_to_code(const uint8_t* code) const override {
assert(precompute_mode == 2);
float dis = this->dis0 +
distance_single_code<PQDecoder>(
PQCodeDistance::distance_single_code(
this->pq.M, this->pq.nbits, this->sim_table, code);
return dis;
}
Expand Down Expand Up @@ -1292,7 +1293,9 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
}
};

template <class PQDecoder, bool use_sel>
/** follow 3 stages of template dispatching */

template <class PQCodeDistance, bool use_sel>
InvertedListScanner* get_InvertedListScanner1(
const IndexIVFPQ& index,
bool store_pairs,
Expand All @@ -1301,32 +1304,47 @@ InvertedListScanner* get_InvertedListScanner1(
return new IVFPQScanner<
METRIC_INNER_PRODUCT,
CMin<float, idx_t>,
PQDecoder,
PQCodeDistance,
use_sel>(index, store_pairs, 2, sel);
} else if (index.metric_type == METRIC_L2) {
return new IVFPQScanner<
METRIC_L2,
CMax<float, idx_t>,
PQDecoder,
PQCodeDistance,
use_sel>(index, store_pairs, 2, sel);
}
return nullptr;
}

template <bool use_sel>
template <bool use_sel, SIMDLevel SL>
InvertedListScanner* get_InvertedListScanner2(
const IndexIVFPQ& index,
bool store_pairs,
const IDSelector* sel) {
if (index.pq.nbits == 8) {
return get_InvertedListScanner1<PQDecoder8, use_sel>(
index, store_pairs, sel);
return get_InvertedListScanner1<
PQCodeDistance<PQDecoder8, SL>,
use_sel>(index, store_pairs, sel);
} else if (index.pq.nbits == 16) {
return get_InvertedListScanner1<PQDecoder16, use_sel>(
index, store_pairs, sel);
return get_InvertedListScanner1<
PQCodeDistance<PQDecoder16, SL>,
use_sel>(index, store_pairs, sel);
} else {
return get_InvertedListScanner1<
PQCodeDistance<PQDecoderGeneric, SL>,
use_sel>(index, store_pairs, sel);
}
}

template <SIMDLevel SL>
InvertedListScanner* get_InvertedListScanner3(
const IndexIVFPQ& index,
bool store_pairs,
const IDSelector* sel) {
if (sel) {
return get_InvertedListScanner2<true, SL>(index, store_pairs, sel);
} else {
return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
index, store_pairs, sel);
return get_InvertedListScanner2<false, SL>(index, store_pairs, sel);
}
}

Expand All @@ -1336,11 +1354,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
bool store_pairs,
const IDSelector* sel,
const IVFSearchParameters*) const {
if (sel) {
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
} else {
return get_InvertedListScanner2<false>(*this, store_pairs, sel);
}
DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel);
return nullptr;
}

Expand Down
32 changes: 21 additions & 11 deletions faiss/IndexPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void IndexPQ::train(idx_t n, const float* x) {

namespace {

template <class PQDecoder>
template <class PQCodeDistance>
struct PQDistanceComputer : FlatCodesDistanceComputer {
size_t d;
MetricType metric;
Expand All @@ -85,7 +85,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
float distance_to_code(const uint8_t* code) final {
ndis++;

float dis = distance_single_code<PQDecoder>(
float dis = PQCodeDistance::distance_single_code(
pq.M, pq.nbits, precomputed_table.data(), code);
return dis;
}
Expand All @@ -94,8 +94,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
FAISS_THROW_IF_NOT(sdc);
const float* sdci = sdc;
float accu = 0;
PQDecoder codei(codes + i * code_size, pq.nbits);
PQDecoder codej(codes + j * code_size, pq.nbits);
typename PQCodeDistance::PQDecoder codei(
codes + i * code_size, pq.nbits);
typename PQCodeDistance::PQDecoder codej(
codes + j * code_size, pq.nbits);

for (int l = 0; l < pq.M; l++) {
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
Expand Down Expand Up @@ -131,16 +133,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
}
};

template <SIMDLevel SL>
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
const IndexPQ& index) {
int nbits = index.pq.nbits;
if (nbits == 8) {
return new PQDistanceComputer<PQCodeDistance<PQDecoder8, SL>>(index);
} else if (nbits == 16) {
return new PQDistanceComputer<PQCodeDistance<PQDecoder16, SL>>(index);
} else {
return new PQDistanceComputer<PQCodeDistance<PQDecoderGeneric, SL>>(
index);
}
}

} // namespace

FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
if (pq.nbits == 8) {
return new PQDistanceComputer<PQDecoder8>(*this);
} else if (pq.nbits == 16) {
return new PQDistanceComputer<PQDecoder16>(*this);
} else {
return new PQDistanceComputer<PQDecoderGeneric>(*this);
}
DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
}

/*****************************************
Expand Down
Loading
Loading