Skip to content

Commit

Permalink
[Embedding] Refactor the data structure of EmbeddingVariable. (#924)
Browse files Browse the repository at this point in the history
Signed-off-by: lixy9474 <[email protected]>
  • Loading branch information
lixy9474 authored Oct 17, 2023
1 parent 29ecde4 commit 06f81cc
Show file tree
Hide file tree
Showing 62 changed files with 3,060 additions and 3,738 deletions.
77 changes: 42 additions & 35 deletions tensorflow/core/framework/embedding/bloom_filter_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
using FilterPolicy<K, V, EV>::config_;

public:
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev) :
FilterPolicy<K, V, EV>(config, ev) {

BloomFilterPolicy(const EmbeddingConfig& config, EV* ev,
embedding::FeatureDescriptor<V>* feat_desc)
: feat_desc_(feat_desc),
FilterPolicy<K, V, EV>(config, ev) {
switch (config_.counter_type){
case DT_UINT64:
VLOG(2) << "The type of bloom counter is uint64";
Expand All @@ -64,10 +65,10 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {

Status Lookup(K key, V* val, const V* default_value_ptr,
const V* default_value_no_permission) override {
ValuePtr<V>* value_ptr = nullptr;
void* value_ptr = nullptr;
Status s = ev_->LookupKey(key, &value_ptr);
if (s.ok()) {
V* mem_val = ev_->LookupOrCreateEmb(value_ptr, default_value_ptr);
V* mem_val = feat_desc_->GetEmbedding(value_ptr, config_.emb_index);
memcpy(val, mem_val, sizeof(V) * ev_->ValueLen());
} else {
memcpy(val, default_value_no_permission, sizeof(V) * ev_->ValueLen());
Expand All @@ -81,17 +82,17 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
int64 num_of_keys,
V* default_value_ptr,
V* default_value_no_permission) override {
std::vector<ValuePtr<V>*> value_ptr_list(num_of_keys, nullptr);
std::vector<void*> value_ptr_list(num_of_keys, nullptr);
ev_->BatchLookupKey(ctx, keys, value_ptr_list.data(), num_of_keys);
std::vector<V*> embedding_ptr(num_of_keys, nullptr);
auto do_work = [this, value_ptr_list, &embedding_ptr,
default_value_ptr, default_value_no_permission]
(int64 start, int64 limit) {
for (int i = start; i < limit; i++) {
ValuePtr<V>* value_ptr = value_ptr_list[i];
void* value_ptr = value_ptr_list[i];
if (value_ptr != nullptr) {
embedding_ptr[i] =
ev_->LookupOrCreateEmb(value_ptr, default_value_ptr);
feat_desc_->GetEmbedding(value_ptr, config_.emb_index);
} else {
embedding_ptr[i] = default_value_no_permission;
}
Expand All @@ -109,13 +110,13 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
}

void BatchLookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& ctx,
const K* keys, ValuePtr<V>** value_ptrs_list,
const K* keys, void** value_ptrs_list,
int64 num_of_keys) {
int num_worker_threads = ctx.worker_threads->num_threads;
std::vector<std::vector<K>> lookup_or_create_ids(num_worker_threads);
std::vector<std::vector<int>>
lookup_or_create_cursor(num_worker_threads);
std::vector<std::vector<ValuePtr<V>*>>
std::vector<std::vector<void*>>
lookup_or_create_ptrs(num_worker_threads);
IntraThreadCopyIdAllocator thread_copy_id_alloc(num_worker_threads);
std::vector<std::list<int64>>
Expand Down Expand Up @@ -147,7 +148,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
1000, do_work);

std::vector<K> total_ids(num_of_keys);
std::vector<ValuePtr<V>*> total_ptrs(num_of_keys);
std::vector<void*> total_ptrs(num_of_keys);
std::vector<int> total_cursors(num_of_keys);
int num_of_admit_id = 0;
for (int i = 0; i < num_worker_threads; i++) {
Expand All @@ -157,7 +158,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
sizeof(K) * lookup_or_create_ids[i].size());
memcpy(total_ptrs.data() + num_of_admit_id,
lookup_or_create_ptrs[i].data(),
sizeof(ValuePtr<V>*) * lookup_or_create_ptrs[i].size());
sizeof(void*) * lookup_or_create_ptrs[i].size());
memcpy(total_cursors.data() + num_of_admit_id,
lookup_or_create_cursor[i].data(),
sizeof(int) * lookup_or_create_cursor[i].size());
Expand All @@ -174,31 +175,40 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
#endif //GOOGLE_CUDA

// NOTE(review): this span is a rendered diff hunk, so removed and added lines
// sit side by side without +/- markers. The ValuePtr<V>** /
// ev_->LookupOrCreateEmb lines look like the pre-commit text and the
// void** / feat_desc_ lines like the post-commit text -- confirm against the
// commit before treating either variant as current.
//
// Writes the embedding for `key` into `val`.
//  - If GetBloomFreq(key) has reached config_.filter_freq, the key's value
//    pointer is obtained through `value_ptr` and its embedding is copied out.
//  - Otherwise the key's frequency is raised by `count` and
//    `default_value_no_permission` is copied out instead.
// Both paths copy sizeof(V) * ev_->ValueLen() bytes into `val`.
void LookupOrCreate(K key, V* val, const V* default_value_ptr,
ValuePtr<V>** value_ptr, int count,
void** value_ptr, int count,
const V* default_value_no_permission) override {
if (GetBloomFreq(key) >= config_.filter_freq) {
// Admitted path. TF_CHECK_OK is expected to abort on a non-OK status, so a
// failed lookup/creation is not reported back to the caller.
TF_CHECK_OK(ev_->LookupOrCreateKey(key, value_ptr));
V* mem_val = ev_->LookupOrCreateEmb(*value_ptr, default_value_ptr);
// NOTE(review): the feat_desc_ variant below never reads default_value_ptr;
// it goes through this class's own LookupOrCreateKey, which re-evaluates
// the frequency test and writes is_filter itself.
bool is_filter = true;
TF_CHECK_OK(LookupOrCreateKey(key, value_ptr, &is_filter, count));
V* mem_val = feat_desc_->GetEmbedding(*value_ptr, config_.emb_index);
memcpy(val, mem_val, sizeof(V) * ev_->ValueLen());
} else {
// Not admitted yet: only record the occurrence and hand back the
// no-permission default; *value_ptr is left untouched on this path.
AddFreq(key, count);
memcpy(val, default_value_no_permission, sizeof(V) * ev_->ValueLen());
}
}

// NOTE(review): rendered diff hunk -- removed and added lines are interleaved
// without +/- markers. The lines naming the output `val` appear to be the
// pre-commit text and the lines naming it `value_ptr` the post-commit text;
// confirm against the commit.
//
// Applies the frequency filter to `key` and reports the outcome in
// `*is_filter`. The test is GetFreq(key, ...) + count >= config_.filter_freq.
//  - `val` variant: on pass, sets *is_filter = true and returns
//    ev_->LookupOrCreateKey(key, val); otherwise sets *is_filter = false,
//    calls AddFreq(key, count) and returns Status::OK().
//  - `value_ptr` variant: on pass, looks the key up with ev_->LookupKey and,
//    if that fails, allocates a value through feat_desc_, sets its default
//    value and inserts it into ev_->storage(); then sets *is_filter = true
//    and calls feat_desc_->AddFreq(*value_ptr, count). On a miss it sets
//    *is_filter = false and calls AddFreq(key, count). Returns Status::OK().
Status LookupOrCreateKey(K key, ValuePtr<V>** val,
Status LookupOrCreateKey(K key, void** value_ptr,
bool* is_filter, int64 count) override {
// The output pointer is cleared before the frequency test in both variants.
*val = nullptr;
if ((GetFreq(key, *val) + count) >= config_.filter_freq) {
*value_ptr = nullptr;
if ((GetFreq(key, *value_ptr) + count) >= config_.filter_freq) {
Status s = ev_->LookupKey(key, value_ptr);
if (!s.ok()) {
// Key not found: build a fresh value and register it with the storage.
*value_ptr = feat_desc_->Allocate();
feat_desc_->SetDefaultValue(*value_ptr, key);
// NOTE(review): Allocate()'s result lives in *value_ptr, yet Insert is
// handed value_ptr itself -- confirm which one Insert expects.
ev_->storage()->Insert(key, value_ptr);
// NOTE(review): `s` is not read again in the lines shown here; the
// function returns Status::OK() directly.
s = Status::OK();
}
*is_filter = true;
return ev_->LookupOrCreateKey(key, val);
feat_desc_->AddFreq(*value_ptr, count);
} else {
*is_filter = false;
AddFreq(key, count);
}
*is_filter = false;
AddFreq(key, count);
return Status::OK();
}

// NOTE(review): rendered diff hunk -- the two signature lines are presumably
// the pre-commit (ValuePtr<V>*) and post-commit (void*) forms of one method.
//
// Returns GetBloomFreq(key). The pointer argument is not used by the body in
// either form.
int64 GetFreq(K key, ValuePtr<V>*) override {
int64 GetFreq(K key, void* val) override {
return GetBloomFreq(key);
}

Expand All @@ -210,7 +220,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
return bloom_counter_;
}

bool is_admit(K key, ValuePtr<V>* value_ptr) override {
bool is_admit(K key, void* value_ptr) override {
if (value_ptr == nullptr) {
return false;
} else {
Expand Down Expand Up @@ -326,8 +336,12 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
LOG(INFO) << "skip EV key:" << *(key_buff + i);
continue;
}
ValuePtr<V>* value_ptr = nullptr;
void* value_ptr = nullptr;
int64 new_freq = freq_buff[i];
int64 import_version = -1;
if (config_.steps_to_live != 0 || config_.record_version) {
import_version = version_buff[i];
}
if (!is_filter) {
if (freq_buff[i] >= config_.filter_freq) {
SetBloomFreq(key_buff[i], freq_buff[i]);
Expand All @@ -339,17 +353,9 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
SetBloomFreq(key_buff[i], freq_buff[i]);
}
if (new_freq >= config_.filter_freq){
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
if (config_.steps_to_live != 0 || config_.record_version) {
value_ptr->SetStep(version_buff[i]);
}
if (!is_filter){
ev_->LookupOrCreateEmb(value_ptr,
value_buff + i * ev_->ValueLen());
} else {
ev_->LookupOrCreateEmb(value_ptr,
ev_->GetDefaultValue(key_buff[i]));
}
ev_->storage()->Import(key_buff[i],
value_buff + i * ev_->ValueLen(),
new_freq, import_version, config_.emb_index);
}
}
return Status::OK();
Expand Down Expand Up @@ -449,6 +455,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
}
private:
void* bloom_counter_;
embedding::FeatureDescriptor<V>* feat_desc_;
std::vector<int64> seeds_;
};
} // tensorflow
Expand Down
6 changes: 1 addition & 5 deletions tensorflow/core/framework/embedding/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ enum EmbeddingVariableType {
// NOTE(review): rendered diff hunk without +/- markers; the file header for
// this hunk reports 1 addition and 5 deletions. The likely reading is that
// the first closing brace, the blank line, the `enum ValuePosition {` header
// and its IN_DRAM = 0 / NOT_IN_DRAM = 1 entries are the deleted lines, and
// `NOT_IN_DRAM = 2;` is the addition. That would fold NOT_IN_DRAM into
// ValuePtrStatus as a third value and drop ValuePosition -- confirm against
// the commit.
enum ValuePtrStatus {
OK = 0;
IS_DELETED = 1;
}

enum ValuePosition {
IN_DRAM = 0;
NOT_IN_DRAM = 1;
NOT_IN_DRAM = 2;
}

enum IsSetInitialized {
Expand Down
Loading

0 comments on commit 06f81cc

Please sign in to comment.