diff --git a/benchs/bench_fw/optimize.py b/benchs/bench_fw/optimize.py
index b3d62980c3..ac6c45ab0c 100644
--- a/benchs/bench_fw/optimize.py
+++ b/benchs/bench_fw/optimize.py
@@ -228,6 +228,7 @@ def optimize_codec(
         (None, "SQfp16"),
         (None, "SQbf16"),
         (None, "SQ8"),
+        (None, "SQ8_direct_signed"),
     ] + [
         (f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
         for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
diff --git a/c_api/IndexScalarQuantizer_c.h b/c_api/IndexScalarQuantizer_c.h
index 87fe6d3415..55a2676d22 100644
--- a/c_api/IndexScalarQuantizer_c.h
+++ b/c_api/IndexScalarQuantizer_c.h
@@ -27,6 +27,8 @@ typedef enum FaissQuantizerType {
     QT_8bit_direct, ///< fast indexing of uint8s
     QT_6bit,        ///< 6 bits per component
     QT_bf16,
+    QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
+                           ///< [-128 to 127]
 } FaissQuantizerType;
 
 // forward declaration
diff --git a/faiss/IndexScalarQuantizer.cpp b/faiss/IndexScalarQuantizer.cpp
index 7ce838db5e..44a628f000 100644
--- a/faiss/IndexScalarQuantizer.cpp
+++ b/faiss/IndexScalarQuantizer.cpp
@@ -33,7 +33,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
         : IndexFlatCodes(0, d, metric), sq(d, qtype) {
     is_trained = qtype == ScalarQuantizer::QT_fp16 ||
             qtype == ScalarQuantizer::QT_8bit_direct ||
-            qtype == ScalarQuantizer::QT_bf16;
+            qtype == ScalarQuantizer::QT_bf16 ||
+            qtype == ScalarQuantizer::QT_8bit_direct_signed;
     code_size = sq.code_size;
 }
 
diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp
index 7ad50189e4..528843f606 100644
--- a/faiss/impl/ScalarQuantizer.cpp
+++ b/faiss/impl/ScalarQuantizer.cpp
@@ -621,13 +621,90 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
 
     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
-        float32_t result[8] = {};
-        for (size_t j = 0; j < 8; j++) {
-            result[j] = code[i + j];
+        uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
+        uint16x8_t y8 = vmovl_u8(x8);
+        uint16x4_t y8_0 = vget_low_u16(y8);
+        uint16x4_t y8_1 = vget_high_u16(y8);
+
+        // convert uint16 -> uint32 -> fp32
+        return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
+    }
+};
+
+#endif
+
+/*******************************************************************
+ * 8bit_direct_signed quantizer
+ *******************************************************************/
+
+template <int SIMDWIDTH>
+struct Quantizer8bitDirectSigned {};
+
+template <>
+struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
+    const size_t d;
+
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
+            : d(d) {}
+
+    void encode_vector(const float* x, uint8_t* code) const final {
+        for (size_t i = 0; i < d; i++) {
+            code[i] = (uint8_t)(x[i] + 128);
         }
-        float32x4_t res1 = vld1q_f32(result);
-        float32x4_t res2 = vld1q_f32(result + 4);
-        return {res1, res2};
+    }
+
+    void decode_vector(const uint8_t* code, float* x) const final {
+        for (size_t i = 0; i < d; i++) {
+            x[i] = code[i] - 128;
+        }
+    }
+
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
+        return code[i] - 128;
+    }
+};
+
+#ifdef __AVX2__
+
+template <>
+struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirectSigned<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
+        __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
+        __m256i c8 = _mm256_set1_epi32(128);
+        __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
+        return _mm256_cvtepi32_ps(z8); // 8 * float32
+    }
+};
+
+#endif
+
+#ifdef __aarch64__
+
+template <>
+struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirectSigned<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
+        uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
+        uint16x4_t y8_0 = vget_low_u16(y8);
+        uint16x4_t y8_1 = vget_high_u16(y8);
+
+        float32x4_t z8_0 = vcvtq_f32_u32(
+                vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
+        float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));
+
+        // subtract 128 to convert into signed numbers
+        return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
+                vsubq_f32(z8_1, vmovq_n_f32(128.0))};
     }
 };
 
@@ -660,6 +737,8 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
             return new QuantizerBF16<SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_8bit_direct:
             return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
     }
     FAISS_THROW_MSG("unknown qtype");
 }
@@ -1460,6 +1539,11 @@ SQDistanceComputer* select_distance_computer(
                         Sim,
                         SIMDWIDTH>(d, trained);
             }
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return new DCTemplate<
+                    Quantizer8bitDirectSigned<SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
     }
     FAISS_THROW_MSG("unknown qtype");
     return nullptr;
@@ -1483,6 +1567,7 @@ void ScalarQuantizer::set_derived_sizes() {
         case QT_8bit:
         case QT_8bit_uniform:
         case QT_8bit_direct:
+        case QT_8bit_direct_signed:
             code_size = d;
             bits = 8;
             break;
@@ -1540,6 +1625,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
         case QT_fp16:
         case QT_8bit_direct:
        case QT_bf16:
+        case QT_8bit_direct_signed:
             // no training necessary
             break;
     }
@@ -1885,6 +1971,11 @@ InvertedListScanner* sel1_InvertedListScanner(
                         Similarity,
                         SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
             }
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return sel2_InvertedListScanner<DCTemplate<
+                    Quantizer8bitDirectSigned<SIMDWIDTH>,
+                    Similarity,
+                    SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
     }
 
     FAISS_THROW_MSG("unknown qtype");
diff --git a/faiss/impl/ScalarQuantizer.h b/faiss/impl/ScalarQuantizer.h
index 49fd42cc31..904e6f6b60 100644
--- a/faiss/impl/ScalarQuantizer.h
+++ b/faiss/impl/ScalarQuantizer.h
@@ -33,6 +33,8 @@ struct ScalarQuantizer : Quantizer {
         QT_8bit_direct, ///< fast indexing of uint8s
         QT_6bit,        ///< 6 bits per component
         QT_bf16,
+        QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
+                               ///< [-128 to 127]
     };
 
     QuantizerType qtype = QT_8bit;
diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp
index d88fe7b393..564a164e79 100644
--- a/faiss/index_factory.cpp
+++ b/faiss/index_factory.cpp
@@ -141,8 +141,11 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
         {"SQ6", ScalarQuantizer::QT_6bit},
         {"SQfp16", ScalarQuantizer::QT_fp16},
         {"SQbf16", ScalarQuantizer::QT_bf16},
+        {"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
+        {"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
 };
-const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)";
+const std::string sq_pattern =
+        "(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
 
 std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
         {"_Nfloat", AdditiveQuantizer::ST_norm_float},
diff --git a/tests/test_index_accuracy.py b/tests/test_index_accuracy.py
index 8d8b4a28f6..2c5cf7b901 100644
--- a/tests/test_index_accuracy.py
+++ b/tests/test_index_accuracy.py
@@ -312,7 +312,7 @@ def test_parallel_mode(self):
 
 
 class TestSQByte(unittest.TestCase):
-    def subtest_8bit_direct(self, metric_type, d):
+    def subtest_8bit_direct(self, metric_type, d, quantizer_type):
         xt, xb, xq = get_dataset_2(d, 500, 1000, 30)
 
         # rescale everything to get integer
@@ -324,16 +324,28 @@ def rescale(x):
             x[x > 255] = 255
             return x
 
-        xt = rescale(xt)
-        xb = rescale(xb)
-        xq = rescale(xq)
+        def rescale_signed(x):
+            x = np.floor((x - tmin) * 256 / (tmax - tmin))
+            x[x < 0] = 0
+            x[x > 255] = 255
+            x -= 128
+            return x
+
+        if quantizer_type == faiss.ScalarQuantizer.QT_8bit_direct_signed:
+            xt = rescale_signed(xt)
+            xb = rescale_signed(xb)
+            xq = rescale_signed(xq)
+        else:
+            xt = rescale(xt)
+            xb = rescale(xb)
+            xq = rescale(xq)
 
         gt_index = faiss.IndexFlat(d, metric_type)
         gt_index.add(xb)
         Dref, Iref = gt_index.search(xq, 10)
 
         index = faiss.IndexScalarQuantizer(
-            d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type
+            d, quantizer_type, metric_type
         )
         index.add(xb)
         D, I = index.search(xq, 10)
@@ -353,7 +365,7 @@ def rescale(x):
         Dref, Iref = gt_index.search(xq, 10)
 
         index = faiss.IndexIVFScalarQuantizer(
-            quantizer, d, nlist, faiss.ScalarQuantizer.QT_8bit_direct,
+            quantizer, d, nlist, quantizer_type,
             metric_type
         )
         index.nprobe = 4
@@ -366,9 +378,10 @@ def rescale(x):
         assert np.all(D == Dref)
 
     def test_8bit_direct(self):
-        for d in 13, 16, 24:
-            for metric_type in faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT:
-                self.subtest_8bit_direct(metric_type, d)
+        for quantizer in faiss.ScalarQuantizer.QT_8bit_direct, faiss.ScalarQuantizer.QT_8bit_direct_signed:
+            for d in 13, 16, 24:
+                for metric_type in faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT:
+                    self.subtest_8bit_direct(metric_type, d, quantizer)
 
 
 class TestNNDescent(unittest.TestCase):
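Usage sketch (not part of the diff): with the factory string registered above, the new quantizer type can be exercised roughly as follows. The dimension and the random data drawn from [-128, 127] are illustrative only; any float vectors whose components already hold signed 8-bit values would behave the same way.

    import faiss
    import numpy as np

    d = 32
    # synthetic vectors whose components already lie in [-128, 127]
    xb = np.random.randint(-128, 128, size=(1000, d)).astype('float32')
    xq = np.random.randint(-128, 128, size=(20, d)).astype('float32')

    # "SQ8_direct_signed" maps to ScalarQuantizer::QT_8bit_direct_signed;
    # no training step is needed since values are stored directly as 8-bit codes
    # (is_trained is set for this qtype in IndexScalarQuantizer above).
    index = faiss.index_factory(d, "SQ8_direct_signed")
    index.add(xb)
    D, I = index.search(xq, 5)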