From ec495836450f7134069431662daf4d06e3c0bae9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Jun 2021 14:18:06 +0800 Subject: [PATCH] Simplify code for getting leftover_index --- k2/csrc/hash.h | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/k2/csrc/hash.h b/k2/csrc/hash.h index 4a618c83c..330c926f9 100644 --- a/k2/csrc/hash.h +++ b/k2/csrc/hash.h @@ -157,8 +157,7 @@ class Hash { uint64_t key_value = src_data[i]; if (~key_value == 0) return; // equals -1.. nothing there. uint64_t key = key_value & key_mask, - leftover_index = 1 | - ((key >> new_buckets_num_bitsm1) ^ (key & new_num_buckets_mask)); + leftover_index = 1 | ((key >> new_buckets_num_bitsm1) ^ key); size_t cur_bucket = key & new_num_buckets_mask; while (1) { uint64_t assumed = ~((uint64_t)0), @@ -286,8 +285,7 @@ class Hash { uint64_t *old_value = nullptr, uint64_t **key_value_location = nullptr) const { uint32_t cur_bucket = static_cast(key) & num_buckets_mask_, - leftover_index = - 1 | ((key >> buckets_num_bitsm1_) & (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); constexpr int64_t KEY_MASK = (uint64_t(1)<> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if (~old_elem == 0) { @@ -439,8 +436,7 @@ class Hash { __forceinline__ __host__ __device__ void Delete(uint64_t key) const { constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = - 1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if ((old_elem & KEY_MASK) == key) { @@ -509,8 +505,7 @@ class Hash { uint64_t *old_value = nullptr, uint64_t **key_value_location = nullptr) const { uint32_t cur_bucket = static_cast(key) & num_buckets_mask_, - leftover_index = - 1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); const uint32_t num_key_bits = num_key_bits_; const uint64_t key_mask = (uint64_t(1) << num_key_bits) - 1, not_value_mask = (uint64_t(-1) << (64 - num_key_bits)); @@ -576,8 +571,7 @@ class Hash { const int64_t key_mask = (uint64_t(1) << num_key_bits) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = - 1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if (~old_elem == 0) { @@ -653,8 +647,7 @@ class Hash { */ __forceinline__ __host__ __device__ void Delete(uint64_t key) const { uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = - 1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)); + leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); const uint64_t key_mask = (uint64_t(1) << num_key_bits_) - 1; while (1) { uint64_t old_elem = data_[cur_bucket]; @@ -744,8 +737,7 @@ class Hash { // the lowest-order `num_implicit_key_bits_` bits of the bucket index will // not change when we fail over to the next location. Without this, our // scheme would not work. - uint32_t leftover_index = - (1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_))) + uint32_t leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; @@ -808,8 +800,7 @@ class Hash { const int64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = - (1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_))) + leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; @@ -892,8 +883,7 @@ class Hash { */ __forceinline__ __host__ __device__ void Delete(uint64_t key) const { uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = - (1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_))) + leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; const uint64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1;