Skip to content

Commit

Permalink
vectorized
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamonDinoia committed Sep 12, 2024
1 parent 7aa45ec commit 8da21c2
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions src/spreadinterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1889,13 +1889,24 @@ void bin_sort_singlethread_vector(
static constexpr auto simd_size = simd_type::size;
static constexpr auto alignment = arch_t::alignment();

constexpr auto to_array = [](const auto &vec) constexpr noexcept {
static constexpr auto to_array = [](const auto &vec) constexpr noexcept {
using T = decltype(std::decay_t<decltype(vec)>());
alignas(alignment) std::array<typename T::value_type, T::size> array{};
vec.store_aligned(array.data());
return array;
};

// Returns true when at least two lanes of the SIMD batch `vec` hold the same
// value, false when all lanes are pairwise distinct.
//
// Callers use the result to decide whether a gather / modify / scatter on
// `vec` used as an index batch is safe: if two lanes address the same slot,
// the scatter would keep only one of the updates, so the scalar fallback has
// to be taken instead.
//
// The lanes are spilled to a small stack array and compared pairwise. Each
// lane is only compared against the lanes that follow it, so a lane is never
// compared with itself (which would trivially report a duplicate).
static constexpr auto has_duplicates = [](const auto &vec) constexpr noexcept {
  using batch_t = std::decay_t<decltype(vec)>;
  using lane_t  = typename batch_t::value_type;
  constexpr std::size_t n_lanes = batch_t::size;

  std::array<lane_t, n_lanes> lanes{};
  vec.store_unaligned(lanes.data()); // unaligned store: no alignment requirement on `lanes`

  for (std::size_t a = 0; a + 1 < n_lanes; ++a) {
    for (std::size_t b = a + 1; b < n_lanes; ++b) {
      if (lanes[a] == lanes[b]) {
        return true;
      }
    }
  }
  return false;
};

const auto isky = (N2 > 1), iskz = (N3 > 1); // ky,kz avail? (cannot access if not)
// here the +1 is needed to allow round-off error causing i1=N1/bin_size_x,
// for kx near +pi, ie foldrescale gives N1 (exact arith would be 0 to N1-1).
Expand All @@ -1916,8 +1927,6 @@ void bin_sort_singlethread_vector(

// count how many pts in each bin
alignas(alignment) std::vector<xsimd::as_integer_t<FLT>> counts(nbins + simd_size, 0);
alignas(alignment) std::vector<xsimd::as_integer_t<FLT>> ref_counts(nbins + simd_size,
0);
const auto simd_M = M & (-simd_size); // round down to simd_size multiple
UBIGINT i{};
for (i = 0; i < simd_M; i += simd_size) {
Expand All @@ -1931,14 +1940,17 @@ void bin_sort_singlethread_vector(
iskz ? xsimd::to_int(fold_rescale(simd_type::load_unaligned(kz + i), N3) *
inv_bin_size_z_vec)
: zero;
const auto bin = i1 + nbins1 * (i2 + nbins2 * i3);
const auto bin_array = to_array(bin);
for (int j = 0; j < simd_size; j++) {
++ref_counts[bin_array[j]];
const auto bin = i1 + nbins1 * (i2 + nbins2 * i3);
if (has_duplicates(bin)) {
const auto bin_array = to_array(bin);
for (int j = 0; j < simd_size; j++) {
++counts[bin_array[j]];
}
} else {
const auto bins = int_simd_type::gather(counts.data(), bin);
const auto incr_bins = bins + 1;
incr_bins.scatter(counts.data(), bin);
}
const auto bins = int_simd_type::gather(counts.data(), bin);
const auto incr_bins = bins + 1;
incr_bins.scatter(counts.data(), bin);
}

for (; i < M; i++) {
Expand All @@ -1948,16 +1960,6 @@ void bin_sort_singlethread_vector(
const auto i3 = iskz ? BIGINT(fold_rescale(kz[i], N3) * inv_bin_size_z) : 0;
const auto bin = i1 + nbins1 * (i2 + nbins2 * i3);
++counts[bin];
++ref_counts[bin];
}

for (i = 0; i < nbins; i++) {
if (counts[i] != ref_counts[i]) {
std::cerr << "Error: bin count mismatch at bin " << i
<< " counts[i] = " << counts[i] << " ref_counts[i] = " << ref_counts[i]
<< std::endl;
std::abort();
}
}

// compute the offsets directly in the counts array (no offset array)
Expand All @@ -1979,16 +1981,19 @@ void bin_sort_singlethread_vector(
iskz ? xsimd::to_int(fold_rescale(simd_type::load_unaligned(kz + i), N3) *
inv_bin_size_z_vec)
: zero;
const auto bin = i1 + nbins1 * (i2 + nbins2 * i3);
// const auto bins = decltype(bin)::gather(counts.data(), bin);
// const auto ret_elems = decltype(bins)::gather(ret, bins) + (increment+i);
// ret_elems.scatter(ret, bins);
// const auto inc_bins = bins+1;
// inc_bins.scatter(counts.data(), bin);
const auto bin_array = to_array(to_int(bin));
for (int j = 0; j < simd_size; j++) {
ret[counts[bin_array[j]]] = j + i;
counts[bin_array[j]]++;
const auto bin = i1 + nbins1 * (i2 + nbins2 * i3);
const auto bins = decltype(bin)::gather(counts.data(), bin);
if (has_duplicates(bin) || has_duplicates(bins)) {
const auto bin_array = to_array(to_int(bin));
for (int j = 0; j < simd_size; j++) {
ret[counts[bin_array[j]]] = j + i;
counts[bin_array[j]]++;
}
} else {
const auto incr_bins = bins + 1;
incr_bins.scatter(counts.data(), bin);
const auto result = increment + i;
result.scatter(ret, bins);
}
}
for (; i < M; i++) {
Expand Down

0 comments on commit 8da21c2

Please sign in to comment.