Skip to content

Commit

Permalink
Remove f16c flag in utils_sse (#814)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Sep 6, 2024
1 parent 3f82179 commit ed4cf98
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 93 deletions.
7 changes: 1 addition & 6 deletions cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ if(__X86_64)
add_library(utils_avx OBJECT ${UTILS_AVX_SRC})
add_library(utils_avx512 OBJECT ${UTILS_AVX512_SRC})

check_cxx_compiler_flag("-mf16c" COMPILER_SUPPORTS_F16C)
if(COMPILER_SUPPORTS_F16C)
target_compile_options(utils_sse PRIVATE -msse4.2 -mpopcnt -mf16c)
else()
target_compile_options(utils_sse PRIVATE -msse4.2 -mpopcnt)
endif()
target_compile_options(utils_sse PRIVATE -msse4.2 -mpopcnt)
target_compile_options(utils_avx PRIVATE -mfma -mf16c -mavx2 -mpopcnt)
target_compile_options(utils_avx512 PRIVATE -mfma -mf16c -mavx512f -mavx512dq
-mavx512bw -mpopcnt -mavx512vl)
Expand Down
62 changes: 0 additions & 62 deletions src/simd/distances_sse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,6 @@ fvec_norm_L2sqr_sse(const float* x, size_t d) {
return _mm_cvtss_f32(msum1);
}

float
fp16_vec_norm_L2sqr_sse(const knowhere::fp16* x, size_t d) {
__m128 m_res = _mm_setzero_ps();
while (d >= 4) {
__m128 m_x = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i_u*)x));
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_x));
x += 4;
d -= 4;
}
if (d > 0) {
__m128 m_x = _mm_cvtph_ps(mm_masked_read_short(d, (uint16_t*)x));
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_x));
}
m_res = _mm_hadd_ps(m_res, m_res);
m_res = _mm_hadd_ps(m_res, m_res);
return _mm_cvtss_f32(m_res);
}

float
bf16_vec_norm_L2sqr_sse(const knowhere::bf16* x, size_t d) {
__m128 m_res = _mm_setzero_ps();
Expand Down Expand Up @@ -315,29 +297,6 @@ fvec_L2sqr_sse(const float* x, const float* y, size_t d) {
return _mm_cvtss_f32(msum1);
}

float
fp16_vec_L2sqr_sse(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
__m128 m_res = _mm_setzero_ps();
while (d >= 4) {
__m128 m_x = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i_u*)x));
__m128 m_y = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i_u*)y));
m_x = _mm_sub_ps(m_x, m_y);
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_x));
x += 4;
y += 4;
d -= 4;
}
if (d > 0) {
__m128 m_x = _mm_cvtph_ps(mm_masked_read_short(d, (uint16_t*)x));
__m128 m_y = _mm_cvtph_ps(mm_masked_read_short(d, (uint16_t*)y));
m_x = _mm_sub_ps(m_x, m_y);
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_x));
}
m_res = _mm_hadd_ps(m_res, m_res);
m_res = _mm_hadd_ps(m_res, m_res);
return _mm_cvtss_f32(m_res);
}

float
bf16_vec_L2sqr_sse(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) {
__m128 m_res = _mm_setzero_ps();
Expand Down Expand Up @@ -387,27 +346,6 @@ fvec_inner_product_sse(const float* x, const float* y, size_t d) {
return _mm_cvtss_f32(msum1);
}

float
fp16_vec_inner_product_sse(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
__m128 m_res = _mm_setzero_ps();
while (d >= 4) {
__m128 m_x = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)x));
__m128 m_y = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)y));
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_y));
x += 4;
y += 4;
d -= 4;
}
if (d > 0) {
__m128 m_x = _mm_cvtph_ps(mm_masked_read_short(d, (uint16_t*)x));
__m128 m_y = _mm_cvtph_ps(mm_masked_read_short(d, (uint16_t*)y));
m_res = _mm_add_ps(m_res, _mm_mul_ps(m_x, m_y));
}
m_res = _mm_hadd_ps(m_res, m_res);
m_res = _mm_hadd_ps(m_res, m_res);
return _mm_cvtss_f32(m_res);
}

float
bf16_vec_inner_product_sse(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) {
__m128 m_res = _mm_setzero_ps();
Expand Down
9 changes: 0 additions & 9 deletions src/simd/distances_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ namespace faiss {
float
fvec_L2sqr_sse(const float* x, const float* y, size_t d);

float
fp16_vec_L2sqr_sse(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_L2sqr_sse(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// inner product
float
fvec_inner_product_sse(const float* x, const float* y, size_t d);

float
fp16_vec_inner_product_sse(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_inner_product_sse(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

Expand All @@ -49,9 +43,6 @@ fvec_Linf_sse(const float* x, const float* y, size_t d);
float
fvec_norm_L2sqr_sse(const float* x, size_t d);

float
fp16_vec_norm_L2sqr_sse(const knowhere::fp16* x, size_t d);

float
bf16_vec_norm_L2sqr_sse(const knowhere::bf16* x, size_t d);

Expand Down
13 changes: 4 additions & 9 deletions src/simd/hook.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,10 @@ fvec_hook(std::string& simd_type) {
bf16_vec_L2sqr = bf16_vec_L2sqr_sse;
bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_sse;

if (cpu_support_f16c()) {
fp16_vec_inner_product = fp16_vec_inner_product_sse;
fp16_vec_L2sqr = fp16_vec_L2sqr_sse;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_sse;
} else {
fp16_vec_inner_product = fp16_vec_inner_product_ref;
fp16_vec_L2sqr = fp16_vec_L2sqr_ref;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_ref;
}
fp16_vec_inner_product = fp16_vec_inner_product_ref;
fp16_vec_L2sqr = fp16_vec_L2sqr_ref;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_ref;

simd_type = "SSE4_2";
support_pq_fast_scan = false;
} else {
Expand Down
7 changes: 0 additions & 7 deletions tests/ut/test_simd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,6 @@ TEST_CASE("Test fp16 distance", "[fp16]") {
REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr_neon(x.get(), dim), Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f));
#endif
#if defined(__x86_64__)
if (faiss::cpu_support_sse4_2()) {
REQUIRE_THAT(faiss::fp16_vec_L2sqr_sse(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f));
REQUIRE_THAT(faiss::fp16_vec_inner_product_sse(x.get(), y.get(), dim),
Catch::Matchers::WithinRel(ref_ip_dist, 0.001f));
REQUIRE_THAT(faiss::fp16_vec_norm_L2sqr_sse(x.get(), dim),
Catch::Matchers::WithinRel(ref_norm_l2_dist, 0.001f));
}
if (faiss::cpu_support_avx2()) {
REQUIRE_THAT(faiss::fp16_vec_L2sqr_avx(x.get(), y.get(), dim), Catch::Matchers::WithinRel(ref_l2_dist, 0.001f));
REQUIRE_THAT(faiss::fp16_vec_inner_product_avx(x.get(), y.get(), dim),
Expand Down

0 comments on commit ed4cf98

Please sign in to comment.