From a0eb91db73961d53397fbbe83ad6eec4e8388c6a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Jun 2021 14:23:53 +0800 Subject: [PATCH] Simplify code.. --- k2/csrc/hash.h | 57 +++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/k2/csrc/hash.h b/k2/csrc/hash.h index 330c926f9..91b039832 100644 --- a/k2/csrc/hash.h +++ b/k2/csrc/hash.h @@ -51,12 +51,11 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS( - The number of buckets is a power of 2 provided by the user to the constructor; currently no resizing is supported. - When accessing hash[key], we use bucket_index == key % num_buckets, - leftover_index = 1 | ((key * 2) / num_buckets). This is the - leftover part of the index times 2, plus 1. + bucket_inc = 1 | (((key * 2) / num_buckets) ^ key). - If the bucket at `bucket_index` is occupied, we look in locations - `(bucket_index + n * leftover_index)%num_buckets` for n = 1, 2, ...; + `(bucket_index + n * bucket_inc)%num_buckets` for n = 1, 2, ...; this choice ensures that if multiple keys hash to the same bucket, - they don't all access the same sequence of locations; and leftover_index + they don't all access the same sequence of locations; and bucket_inc being odd ensures we eventually try all locations (of course for reasonable hash occupancy levels, we shouldn't ever have to try more than two or three). @@ -157,14 +156,14 @@ class Hash { uint64_t key_value = src_data[i]; if (~key_value == 0) return; // equals -1.. nothing there. uint64_t key = key_value & key_mask, - leftover_index = 1 | ((key >> new_buckets_num_bitsm1) ^ key); + bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key); size_t cur_bucket = key & new_num_buckets_mask; while (1) { uint64_t assumed = ~((uint64_t)0), old_elem = AtomicCAS((unsigned long long*)(data + cur_bucket), assumed, key_value); if (old_elem == assumed) return; - cur_bucket = (cur_bucket + leftover_index) & new_num_buckets_mask; + cur_bucket = (cur_bucket + bucket_inc) & new_num_buckets_mask; // Keep iterating until we find a free spot in the new hash... } }); @@ -285,7 +284,7 @@ class Hash { uint64_t *old_value = nullptr, uint64_t **key_value_location = nullptr) const { uint32_t cur_bucket = static_cast(key) & num_buckets_mask_, - leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); constexpr int64_t KEY_MASK = (uint64_t(1)<> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if (~old_elem == 0) { @@ -362,7 +361,7 @@ class Hash { *key_value_location = data_ + cur_bucket; return true; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } } @@ -436,14 +435,14 @@ class Hash { __forceinline__ __host__ __device__ void Delete(uint64_t key) const { constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if ((old_elem & KEY_MASK) == key) { data_[cur_bucket] = ~((uint64_t)0); return; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } } @@ -505,7 +504,7 @@ class Hash { uint64_t *old_value = nullptr, uint64_t **key_value_location = nullptr) const { uint32_t cur_bucket = static_cast(key) & num_buckets_mask_, - leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); const uint32_t num_key_bits = num_key_bits_; const uint64_t key_mask = (uint64_t(1) << num_key_bits) - 1, not_value_mask = (uint64_t(-1) << (64 - num_key_bits)); @@ -537,10 +536,10 @@ class Hash { } // Rotate bucket index until we find a free location. This will // eventually visit all bucket indexes before it returns to the same - // location, because leftover_index is odd (so only satisfies - // (n * leftover_index) % num_buckets == 0 for n == num_buckets). + // location, because bucket_inc is odd (so only satisfies + // (n * bucket_inc) % num_buckets == 0 for n == num_buckets). // Note: n here is the number of times we went around the loop. - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } @@ -571,7 +570,7 @@ class Hash { const int64_t key_mask = (uint64_t(1) << num_key_bits) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); while (1) { uint64_t old_elem = data_[cur_bucket]; if (~old_elem == 0) { @@ -582,7 +581,7 @@ class Hash { *key_value_location = data_ + cur_bucket; return true; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } } @@ -647,7 +646,7 @@ class Hash { */ __forceinline__ __host__ __device__ void Delete(uint64_t key) const { uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key); + bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key); const uint64_t key_mask = (uint64_t(1) << num_key_bits_) - 1; while (1) { uint64_t old_elem = data_[cur_bucket]; @@ -655,7 +654,7 @@ class Hash { data_[cur_bucket] = ~((uint64_t)0); return; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } } @@ -733,11 +732,11 @@ class Hash { uint64_t *old_value = nullptr, uint64_t **key_value_location = nullptr) const { uint32_t cur_bucket = static_cast(key) & num_buckets_mask_; - // Shifting `leftover_index` right by num_implicit_key_bits_ ensures that + // Shifting `bucket_inc` right by num_implicit_key_bits_ ensures that // the lowest-order `num_implicit_key_bits_` bits of the bucket index will // not change when we fail over to the next location. Without this, our // scheme would not work. - uint32_t leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) + uint32_t bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; @@ -770,7 +769,7 @@ class Hash { } } // Rotate bucket index until we find a free location. - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } @@ -800,7 +799,7 @@ class Hash { const int64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1; uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) + bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; @@ -814,7 +813,7 @@ class Hash { *key_value_location = data_ + cur_bucket; return true; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } } @@ -883,7 +882,7 @@ class Hash { */ __forceinline__ __host__ __device__ void Delete(uint64_t key) const { uint32_t cur_bucket = key & num_buckets_mask_, - leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key)) + bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key)) << num_implicit_key_bits_; uint64_t kept_key = key >> num_implicit_key_bits_; const uint64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1; @@ -893,7 +892,7 @@ class Hash { data_[cur_bucket] = ~((uint64_t)0); return; } else { - cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_; + cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_; } } }