Skip to content

Commit

Permalink
Implement CacheSetUpdater using warp sort
Browse files Browse the repository at this point in the history
  • Loading branch information
chuangz0 committed Aug 9, 2023
1 parent 603f261 commit a28c713
Showing 1 changed file with 95 additions and 107 deletions.
202 changes: 95 additions & 107 deletions cpp/src/wholememory_ops/functions/embedding_cache_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include <stdint.h>

#include "wholegraph_ops/block_topk_with_raft.cuh"
#include <raft/matrix/detail/select_k-inl.cuh>

namespace wholememory_ops {

Expand Down Expand Up @@ -151,14 +151,30 @@ class CacheLineInfo {
uint32_t lfu_count_;
};


template <typename NodeIDT>
class CacheSetUpdater {
public:
static constexpr int kTopKRegisterCount = 4;
static constexpr int kCacheSetSize = CacheLineInfo::kCacheSetSize;
static constexpr int kScaledCounterBits = 14;
using BlockTopK = wholegraph_ops::BlockTopkRaftWarpSort<int64_t, kCacheSetSize, kTopKRegisterCount, kCacheSetSize, false, int>;
struct TempStorage : BlockTopK::TempStorage {};

private:

using warp_bq_t =
raft::matrix::detail::select::warpsort::warp_sort_immediate<kCacheSetSize, false, int64_t, int>;

static constexpr int WARP_SIZE = 32;
static constexpr int BLOCK_SIZE = kCacheSetSize;
static_assert(kCacheSetSize == WARP_SIZE,"only support CacheSetSize==32,and BLOCK_SIZE==32\n");

public:
struct TempStorage {
int64_t store_keys[kCacheSetSize];
int store_values[kCacheSetSize];
};

;
/**
* Starting from an all-invalid CacheSet, recompute the local IDs (lids) to cache and update
* cache_line_info.
* NOTE: the data itself is not loaded here; it must be loaded after this function returns.
Expand All @@ -175,26 +191,14 @@ class CacheSetUpdater {
{
if (id_count <= 0) return;
assert(cache_line_info.IsInValid());
#pragma unroll
for (int i = 0; i < kTopKRegisterCount; i++) {
candidate_lfu_count_[i] = -1;
candidate_local_id_[i] = -1;
}
int base_idx = 0;
int valid_count = 0;
FillCandidate<0, false>(
nullptr, nullptr, memory_lfu_counter, base_idx, valid_count, 0, id_count, temp_storage, -1);
while (base_idx < id_count) {
FillCandidate<1, false>(
nullptr, nullptr, memory_lfu_counter, base_idx, valid_count, 0, id_count, temp_storage, -1);
FillCandidate<2, false>(
nullptr, nullptr, memory_lfu_counter, base_idx, valid_count, 0, id_count, temp_storage, -1);
FillCandidate<3, false>(
nullptr, nullptr, memory_lfu_counter, base_idx, valid_count, 0, id_count, temp_storage, -1);
}

// int base_idx = 0;
// int valid_count = 0;

FillCandidate<false>(nullptr, nullptr, memory_lfu_counter, 0, id_count, temp_storage, -1);
cache_line_info.ClearCacheLine();
cache_line_info.SetLocalID(candidate_local_id_[0]);
cache_line_info.SetScaleLfuCountSync(candidate_local_id_[0] >= 0 ? candidate_lfu_count_[0] : 0);
cache_line_info.SetLocalID(candidate_local_id_);
cache_line_info.SetScaleLfuCountSync(candidate_local_id_ >= 0 ? candidate_lfu_count_ : 0);
}
/**
* Update cache set according to gids and inc_count
Expand Down Expand Up @@ -225,84 +229,58 @@ class CacheSetUpdater {
int id_count)
{
if (id_count <= 0) return;
#pragma unroll
for (int i = 0; i < kTopKRegisterCount; i++) {
candidate_lfu_count_[i] = -1;
candidate_local_id_[i] = -1;
}
int base_idx = 0;
int valid_count = 0;

candidate_lfu_count_ = -1;
candidate_local_id_ = -1;
int cached_local_id = cache_line_info.LocalID();
int has_local_id_count = 0;
has_local_id_count += FillCandidate<0>(gids,
inc_count,
memory_lfu_counter,
base_idx,
valid_count,
set_start_id,
id_count,
temp_storage,
cached_local_id);
while (base_idx < id_count) {
has_local_id_count += FillCandidate<1>(gids,
inc_count,
memory_lfu_counter,
base_idx,
valid_count,
set_start_id,
id_count,
temp_storage,
cached_local_id);
has_local_id_count += FillCandidate<2>(gids,
inc_count,
memory_lfu_counter,
base_idx,
valid_count,
set_start_id,
id_count,
temp_storage,
cached_local_id);
has_local_id_count += FillCandidate<3>(gids,
inc_count,
memory_lfu_counter,
base_idx,
valid_count,
set_start_id,
id_count,
temp_storage,
cached_local_id);
}
// printf("[TopK init dump] threadIdx.x=%d, lfu_count=%ld, lid=%d\n", threadIdx.x,
// candidate_lfu_count_[0], candidate_local_id_[0]);
candidate_lfu_count_[1] = -1;
candidate_local_id_[1] = -1;
int has_local_id_count = FillCandidate(
gids, inc_count, memory_lfu_counter, set_start_id, id_count, temp_storage, cached_local_id);

// printf("[TopK init dump] threadIdx.x=%d, lfu_count=%ld, lid=%d, has_local_id_count = %d \n",
// threadIdx.x,
// candidate_lfu_count_,
// candidate_local_id_,
// has_local_id_count);
int64_t candidate_lfu_count0 = -1;
int candidate_local_id0 = -1;
unsigned int match_flag;
// match_flag = WarpMatchLocalIDPairSync(candidate_local_id_[0], cached_local_id);
int64_t estimated_lfu_count = cache_line_info.LfuCountSync();
// Valid AND NOT exist in update list

if (cached_local_id != -1 && has_local_id_count == 0) {
// cached key not updated, use estimated lfu_count from cache
candidate_lfu_count_[1] = estimated_lfu_count;
candidate_local_id_[1] = cached_local_id;
candidate_lfu_count0 = estimated_lfu_count;
candidate_local_id0 = cached_local_id;
}

warp_bq_t warp_queue(kCacheSetSize);
warp_queue.add(candidate_lfu_count_, candidate_local_id_);
warp_queue.add(candidate_lfu_count0, candidate_local_id0);
warp_queue.done();
warp_queue.store(temp_storage.store_keys, temp_storage.store_values);
__syncthreads();
if (threadIdx.x < kCacheSetSize) {
candidate_lfu_count_ = temp_storage.store_keys[threadIdx.x];
candidate_local_id_ = temp_storage.store_values[threadIdx.x];
}
BlockTopK(temp_storage)
.TopKToStriped(candidate_lfu_count_, candidate_local_id_, kCacheSetSize, kCacheSetSize * 2);

// printf("[TopK merge dump] threadIdx.x=%d, lfu_count=%ld, lid=%d\n", threadIdx.x,
// candidate_lfu_count_[0], candidate_local_id_[0]);
match_flag = WarpMatchLocalIDPairSync(candidate_local_id_[0], cached_local_id);
match_flag = WarpMatchLocalIDPairSync(candidate_local_id_, cached_local_id);
int from_lane = -1;
bool has_match = (cached_local_id >= 0 && match_flag != 0);
if (has_match) from_lane = __ffs(match_flag) - 1;
unsigned int can_update_mask = __ballot_sync(0xFFFFFFFF, !has_match);
unsigned int lower_thread_mask = (1U << threadIdx.x) - 1;
int updatable_cache_line_rank = !has_match ? __popc(can_update_mask & lower_thread_mask) : -1;
unsigned int new_match_flag = WarpMatchLocalIDPairSync(cached_local_id, candidate_local_id_[0]);
unsigned int new_match_flag = WarpMatchLocalIDPairSync(cached_local_id, candidate_local_id_);
// printf("tid=%d, cached_local_id=%d, candidate_local_id_=%d, new_match_flag=%x\n",
// threadIdx.x,
// cached_local_id,
// candidate_local_id_[0],
// candidate_local_id_,
// new_match_flag);
bool new_need_slot = (candidate_local_id_[0] >= 0 && new_match_flag == 0);
bool new_need_slot = (candidate_local_id_ >= 0 && new_match_flag == 0);
unsigned int need_new_slot_mask = __ballot_sync(0xFFFFFFFF, new_need_slot);
int insert_data_rank = new_need_slot ? __popc(need_new_slot_mask & lower_thread_mask) : -1;
// printf("tid=%d, updatable_cache_line_rank=%d, insert_data_rank=%d\n", threadIdx.x,
Expand All @@ -313,8 +291,8 @@ class CacheSetUpdater {
from_lane = __ffs(rank_match_flag) - 1;
}
int src_lane_idx = from_lane >= 0 ? from_lane : 0;
int64_t new_lfu_count = __shfl_sync(0xFFFFFFFF, candidate_lfu_count_[0], src_lane_idx, 32);
int new_local_id = __shfl_sync(0xFFFFFFFF, candidate_local_id_[0], src_lane_idx, 32);
int64_t new_lfu_count = __shfl_sync(0xFFFFFFFF, candidate_lfu_count_, src_lane_idx, 32);
int new_local_id = __shfl_sync(0xFFFFFFFF, candidate_local_id_, src_lane_idx, 32);
if (from_lane == -1) {
new_local_id = -1;
new_lfu_count = 0;
Expand All @@ -323,7 +301,7 @@ class CacheSetUpdater {
// new_lfu_count);
if (NeedOutputLoadIDs && need_load_to_cache_ids != nullptr) {
int new_cached_lid = -1;
if (new_need_slot) { new_cached_lid = candidate_local_id_[0]; }
if (new_need_slot) { new_cached_lid = candidate_local_id_; }
unsigned int load_cache_mask = __ballot_sync(0xFFFFFFFF, new_cached_lid >= 0);
int output_idx = __popc(load_cache_mask & ((1 << threadIdx.x) - 1));
int total_load_count = __popc(load_cache_mask);
Expand Down Expand Up @@ -359,41 +337,51 @@ class CacheSetUpdater {
}

private:
int64_t candidate_lfu_count_[kTopKRegisterCount];
int candidate_local_id_[kTopKRegisterCount];
template <int StrideIdx, bool IncCounter = true>
int64_t candidate_lfu_count_;
int candidate_local_id_;
template <bool IncCounter = true>
__device__ __forceinline__ int FillCandidate(const NodeIDT* gids,
const int* inc_freq_count,
int64_t* cache_set_coverage_counter,
int& base_idx,
int& valid_count,
int64_t cache_set_start_id,
int id_count,
TempStorage& temp_storage,
int cached_local_id)
{
int const idx = base_idx + threadIdx.x;
valid_count += min(kCacheSetSize, max(0, id_count - base_idx));
int local_id = -1;
if (idx < id_count) {
local_id = gids != nullptr ? gids[idx] - cache_set_start_id : idx;
candidate_lfu_count_[StrideIdx] = cache_set_coverage_counter[local_id];
if (IncCounter) {
int id_inc_count = inc_freq_count != nullptr ? inc_freq_count[idx] : 1;
candidate_lfu_count_[StrideIdx] += id_inc_count;
cache_set_coverage_counter[local_id] = candidate_lfu_count_[StrideIdx];

warp_bq_t warp_queue(kCacheSetSize);
const int per_thread_lim = id_count + raft::laneId();

int has_local_id_count = 0;
for (int idx = threadIdx.x; idx < per_thread_lim; idx += BLOCK_SIZE) {
int local_id = -1;
int64_t candidate_lfu_count = -1;
int candidate_local_id = -1;
if (idx < id_count) {
local_id = gids != nullptr ? gids[idx] - cache_set_start_id : idx;
candidate_lfu_count = cache_set_coverage_counter[local_id];
if (IncCounter) {
int id_inc_count = inc_freq_count != nullptr ? inc_freq_count[idx] : 1;
candidate_lfu_count += id_inc_count;
cache_set_coverage_counter[local_id] = candidate_lfu_count;
}
candidate_local_id = local_id;
}
candidate_local_id_[StrideIdx] = local_id;
unsigned int local_id_match_mask = WarpMatchLocalIDPairSync(local_id, cached_local_id);
has_local_id_count += ((cached_local_id != -1) ? __popc(local_id_match_mask) : 0);
warp_queue.add(candidate_lfu_count, candidate_local_id);
}
unsigned int local_id_match_mask = WarpMatchLocalIDPairSync(local_id, cached_local_id);
int has_local_id_count = (cached_local_id != -1) ? __popc(local_id_match_mask) : 0;
if (StrideIdx == kTopKRegisterCount - 1) {
BlockTopK(temp_storage)
.TopKToStriped(
candidate_lfu_count_, candidate_local_id_, min(kCacheSetSize, valid_count), valid_count);
valid_count = min(valid_count, kCacheSetSize);

warp_queue.done();
warp_queue.store(temp_storage.store_keys, temp_storage.store_values);
__syncthreads();
if (threadIdx.x < kCacheSetSize) {
candidate_lfu_count_ = temp_storage.store_keys[threadIdx.x];
candidate_local_id_ = temp_storage.store_values[threadIdx.x];
}
base_idx += kCacheSetSize;
__syncthreads();


return has_local_id_count;
}
};
Expand Down

0 comments on commit a28c713

Please sign in to comment.