[Embedding] Rename api calls.
Signed-off-by: lixy9474 <[email protected]>
lixy9474 committed Sep 21, 2023
1 parent 8c1ed0b commit a97ddf5
Showing 35 changed files with 511 additions and 1,895 deletions.
6 changes: 1 addition & 5 deletions tensorflow/core/framework/embedding/config.proto
@@ -50,11 +50,7 @@ enum EmbeddingVariableType {
enum ValuePtrStatus {
OK = 0;
IS_DELETED = 1;
}

enum ValuePosition {
IN_DRAM = 0;
NOT_IN_DRAM = 1;
NOT_IN_DRAM = 2;
}

enum IsSetInitialized {
@@ -106,6 +106,9 @@ class CounterFilterDescriptorImpl: public FeatureDescriptorImpl<V> {
void* Admit(void* val) override {
if (!IsAdmit(val)) {
return feat_desc_impl_->Allocate();
} else {
LOG(FATAL)<<"Only unadmited feature could be admited.";
return nullptr;
}
}

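A hypothetical sketch (toy class; the real filter's frequency-threshold bookkeeping is omitted) of the contract the added guard enforces: Admit allocates a feature's storage exactly once, and admitting an already-admitted feature is a caller-side fatal error.

```cpp
// Toy stand-in for the counter filter's admit-once contract; mirrors the
// LOG(FATAL) guard above with std::abort().
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>

class ToyCounterFilter {
 public:
  // A feature counts as admitted once real embedding storage exists for it.
  bool IsAdmit(int64_t key) const { return storage_.count(key) != 0; }

  float* Admit(int64_t key, int dim) {
    if (IsAdmit(key)) {
      std::fprintf(stderr, "Only unadmitted features can be admitted.\n");
      std::abort();  // stands in for LOG(FATAL)
    }
    float* emb = new float[dim]();  // zero-initialized toy embedding
    storage_[key] = emb;
    return emb;
  }

 private:
  std::unordered_map<int64_t, float*> storage_;
};

int main() {
  ToyCounterFilter filter;
  if (!filter.IsAdmit(7)) filter.Admit(7, /*dim=*/8);  // first admit allocates
  // Calling filter.Admit(7, 8) again would abort, as in the guard above.
  return 0;
}
```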
15 changes: 6 additions & 9 deletions tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -23,9 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/embedding/kv_interface.h"

namespace tensorflow {
template <class V>
class ValuePtr;

namespace embedding {

template <class K, class V>
@@ -45,7 +42,7 @@ class DenseHashMap : public KVInterface<K, V> {
delete []hash_map_;
}

Status Lookup(K key, ValuePtr<V>** value_ptr) override {
Status Lookup(K key, void** value_ptr) override {
int64 l_id = std::abs(key)%partition_num_;
spin_rd_lock l(hash_map_[l_id].mu);
auto iter = hash_map_[l_id].hash_map.find(key);
@@ -70,7 +67,7 @@ class DenseHashMap : public KVInterface<K, V> {
}
}

Status Insert(K key, const ValuePtr<V>* value_ptr) override {
Status Insert(K key, const void* value_ptr) override {
int64 l_id = std::abs(key)%partition_num_;
spin_wr_lock l(hash_map_[l_id].mu);
auto iter = hash_map_[l_id].hash_map.find(key);
@@ -80,8 +77,8 @@ class DenseHashMap : public KVInterface<K, V> {
"already exists Key: ", key, " in DenseHashMap.");
} else {
auto iter = hash_map_[l_id].hash_map.insert(
std::move(std::pair<K, ValuePtr<V>*>(key,
const_cast<ValuePtr<V>*>(value_ptr))));
std::move(std::pair<K, void*>(key,
const_cast<void*>(value_ptr))));
return Status::OK();
}
}
@@ -109,7 +106,7 @@ class DenseHashMap : public KVInterface<K, V> {
}

Status GetSnapshot(std::vector<K>* key_list,
std::vector<ValuePtr<V>* >* value_ptr_list) override {
std::vector<void*>* value_ptr_list) override {
dense_hash_map hash_map_dump[partition_num_];
for (int i = 0; i< partition_num_; i++) {
spin_rd_lock l(hash_map_[i].mu);
@@ -132,7 +129,7 @@ class DenseHashMap : public KVInterface<K, V> {
const int partition_num_ = 1000;
struct dense_hash_map {
mutable easy_spinrwlock_t mu = EASY_SPINRWLOCK_INITIALIZER;
google::dense_hash_map<K, ValuePtr<V>* > hash_map;
google::dense_hash_map<K, void* > hash_map;
};
dense_hash_map* hash_map_;
};
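As a rough illustration only (hypothetical Toy* types, not the repository's real classes), this is what the renamed interface amounts to for a caller: the KV layer traffics in opaque void* handles, and typed access to the embedding values goes through a feature-descriptor object.

```cpp
// Minimal sketch of a void*-based key/value store plus a descriptor that owns
// the typed view of each value; names and signatures are illustrative.
#include <cstdint>
#include <iostream>
#include <unordered_map>

class ToyFeatureDescriptor {
 public:
  // Allocates an embedding of `dim` floats, returned as an opaque handle.
  void* Allocate(int dim) { return new float[dim](); }
  float* GetEmbedding(void* value_ptr) { return static_cast<float*>(value_ptr); }
};

class ToyKV {
 public:
  bool Lookup(int64_t key, void** value_ptr) const {
    auto it = map_.find(key);
    if (it == map_.end()) return false;
    *value_ptr = it->second;
    return true;
  }
  bool Insert(int64_t key, const void* value_ptr) {
    return map_.emplace(key, const_cast<void*>(value_ptr)).second;
  }

 private:
  std::unordered_map<int64_t, void*> map_;
};

int main() {
  ToyFeatureDescriptor feat_desc;
  ToyKV kv;
  void* value_ptr = feat_desc.Allocate(/*dim=*/8);
  kv.Insert(/*key=*/42, value_ptr);

  void* found = nullptr;
  if (kv.Lookup(42, &found)) {
    feat_desc.GetEmbedding(found)[0] = 1.0f;
    std::cout << feat_desc.GetEmbedding(found)[0] << std::endl;
  }
  delete[] static_cast<float*>(value_ptr);  // toy cleanup
  return 0;
}
```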
12 changes: 4 additions & 8 deletions tensorflow/core/framework/embedding/embedding_memory_pool.h
@@ -18,9 +18,6 @@ limitations under the License.
#include <deque>

namespace tensorflow {
template <class V>
class ValuePtr;

namespace embedding {
template<typename V>
class EmbeddingMemoryPool {
@@ -50,7 +47,7 @@ class EmbeddingMemoryPool {
return ptr;
}

void Deallocate(std::vector<ValuePtr<V>*> value_ptrs) {
void Deallocate(std::vector<void*> value_ptrs) {
int64 prev_size = value_ptrs_queue_.size();
for (auto it : value_ptrs) {
value_ptrs_queue_.emplace_back(it);
@@ -59,9 +56,8 @@ class EmbeddingMemoryPool {
int64 n = value_ptrs_queue_.size() - embs_per_block_;
n = std::min(prev_size, n);
for (int64 i = 0; i < n; i++) {
ValuePtr<V>* val = value_ptrs_queue_.front();
free_ptr_queue_.emplace_back(val->GetValue(0, 0));
delete val;
void* val = value_ptrs_queue_.front();
free_ptr_queue_.emplace_back((V*)val);
value_ptrs_queue_.pop_front();
}
}
@@ -88,7 +84,7 @@ class EmbeddingMemoryPool {
int64 embs_per_block_;
Allocator* alloc_;
std::deque<V*> free_ptr_queue_;
std::deque<ValuePtr<V>*> value_ptrs_queue_;
std::deque<void*> value_ptrs_queue_;
std::vector<V*> block_list_;
};
} //embedding
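For context, a toy re-creation of the recycling behavior suggested by the Deallocate hunk above; the class and member names are illustrative, and the capacity check around the recycling loop is an assumption since those lines are collapsed in the diff.

```cpp
// Returned value pointers are parked in a holding queue; only entries beyond
// one block's capacity are moved to the free list for reuse by later allocations.
#include <algorithm>
#include <cstdint>
#include <deque>
#include <vector>

class ToyPool {
 public:
  explicit ToyPool(int64_t embs_per_block) : embs_per_block_(embs_per_block) {}

  void Deallocate(const std::vector<void*>& value_ptrs) {
    int64_t prev_size = holding_queue_.size();
    for (void* p : value_ptrs) holding_queue_.push_back(p);
    if (holding_queue_.size() > static_cast<size_t>(embs_per_block_)) {
      // Recycle at most prev_size of the oldest pointers, keeping one block's
      // worth of recently returned entries parked in the holding queue.
      int64_t n = static_cast<int64_t>(holding_queue_.size()) - embs_per_block_;
      n = std::min(prev_size, n);
      for (int64_t i = 0; i < n; ++i) {
        free_list_.push_back(static_cast<float*>(holding_queue_.front()));
        holding_queue_.pop_front();
      }
    }
  }

  size_t FreeCount() const { return free_list_.size(); }

 private:
  int64_t embs_per_block_;
  std::deque<void*> holding_queue_;
  std::deque<float*> free_list_;
};

int main() {
  ToyPool pool(/*embs_per_block=*/4);
  std::vector<void*> returned;
  for (int i = 0; i < 6; ++i) returned.push_back(new float[8]());  // toy embeddings
  pool.Deallocate(returned);  // parks 6 pointers; recycles none (prev_size was 0)
  pool.Deallocate({});        // now 2 overflow pointers move to the free list
  return static_cast<int>(pool.FreeCount()) == 2 ? 0 : 1;
}
```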
144 changes: 0 additions & 144 deletions tensorflow/core/framework/embedding/embedding_var.cu.cc
@@ -42,71 +42,6 @@ void SyncWithEventMgr(se::Stream* stream,
while(!is_kernel_finish) {}
}

template <class K, class V>
void EmbeddingVar<K, V>::SetDefaultValueOfNewFeatures(
const K* keys, int64 size, const std::list<int64>& init_cursor,
V** memcpy_address, se::Stream* compute_stream, EventMgr* event_mgr,
const Eigen::GpuDevice& gpu_device) {
if (init_cursor.size() > 0) {
int64 total = init_cursor.size();
V** value_address = nullptr;
value_address = TypedAllocator::Allocate<V*>(cpu_allocator(), total * 2,
AllocationAttributes());
V** default_value_address = value_address + total;
V** dev_value_address = nullptr;
dev_value_address =
TypedAllocator::Allocate<V*>(alloc_, total * 2, AllocationAttributes());
V** dev_default_value_address = dev_value_address + total;
int64 i = 0;
auto it = init_cursor.cbegin();
for (; it != init_cursor.cend(); ++it, ++i) {
ValuePtr<V>* value_ptr =
reinterpret_cast<ValuePtr<V>*>(memcpy_address[*it]);
value_address[i] =
*((V**)((char*)(value_ptr->GetPtr()) + sizeof(FixedLengthHeader))) +
storage_->GetOffset(emb_config_.emb_index);
default_value_address[i] =
default_value_ +
(keys[i] % emb_config_.default_value_dim) % value_len_;
}
DeviceMemoryBase gpu_dst_ptr(dev_value_address, total * 2 * sizeof(V*));
compute_stream->ThenMemcpy(&gpu_dst_ptr, value_address,
total * 2 * sizeof(V*));
int block_dim = 128;
TF_CHECK_OK(GpuLaunchKernel(
embedding::CopyEmbedding<V>,
(total * value_len_ + block_dim - 1) / block_dim,
block_dim, 0, gpu_device.stream(), dev_default_value_address,
dev_value_address, value_len_, total));
SyncWithEventMgr(compute_stream, event_mgr);
// Set init meta of ValuePtrs
for (auto it = init_cursor.cbegin(); it != init_cursor.cend(); ++it) {
ValuePtr<V>* value_ptr =
reinterpret_cast<ValuePtr<V>*>(memcpy_address[*it]);
value_ptr->SetInitialized(emb_config_.emb_index);
memcpy_address[*it] = value_ptr->GetValue(
emb_config_.emb_index,
storage_->GetOffset(emb_config_.emb_index));
}
TypedAllocator::Deallocate(alloc_, dev_value_address, total * 2);
TypedAllocator::Deallocate(cpu_allocator(), value_address, total * 2);
}
}

#define REGISTER_KERNELS(ktype, vtype) \
template void EmbeddingVar<ktype, vtype>::SetDefaultValueOfNewFeatures( \
const ktype*, int64, const std::list<int64>&, vtype**, \
se::Stream*, EventMgr*, const Eigen::GpuDevice& gpu_device);
#define REGISTER_KERNELS_ALL(type) \
REGISTER_KERNELS(int32, type); \
REGISTER_KERNELS(int64, type)
#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU)
#undef REGISTER_KERNELS_CPU

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

template <class K, class V>
void EmbeddingVar<K, V>::CopyEmbeddingsToBuffer(
V* val_base, int64 size, V** memcpy_address,
@@ -136,85 +71,6 @@ void EmbeddingVar<K, V>::CopyEmbeddingsToBuffer(
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU)
#undef REGISTER_KERNELS_CPU

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS

template <class K, class V>
void EmbeddingVar<K, V>::CopyEmbeddingsFromCPUToGPU(
const K* keys, const std::list<int64>& copyback_cursor, V** memcpy_address,
se::Stream* compute_stream, EventMgr* event_mgr,
const Eigen::GpuDevice& gpu_device,
const DeviceBase::CpuWorkerThreads* worker_threads,
int64* output_value_ptrs) {
if (copyback_cursor.size() > 0) {
int64 total = copyback_cursor.size();
size_t value_len = emb_config_.total_num(storage_->GetAllocLen());
V* memcpy_buffer_gpu = nullptr;
ValuePtr<V>** gpu_value_ptrs = new ValuePtr<V>*[total];
memcpy_buffer_gpu = (V*)alloc_->AllocateRaw(Allocator::kAllocatorAlignment,
total * value_len * sizeof(V));
storage_->CopyEmbeddingsFromCPUToGPU(
total, keys, copyback_cursor, memcpy_address, value_len, gpu_value_ptrs,
memcpy_buffer_gpu, compute_stream, event_mgr, worker_threads);

V** value_address = (V**)cpu_allocator()->AllocateRaw(
Allocator::kAllocatorAlignment, sizeof(V*) * total);
V** dev_value_address = (V**)alloc_->AllocateRaw(Allocator::kAllocatorAlignment,
sizeof(V*) * total);
std::vector<K> copyback_keys(total);
int64 i = 0;
auto it = copyback_cursor.cbegin();
for (; it != copyback_cursor.cend(); ++it, ++i) {
bool init;
// Get the curosr
int64 cursor = *it & 0x0fffffffffffffff;
gpu_value_ptrs[i]->SetInitialized(emb_config_.emb_index);
memcpy_address[cursor] = LookupOrCreateEmb(gpu_value_ptrs[i], init);
value_address[i] = memcpy_address[cursor];
copyback_keys[i] = keys[cursor];
}
DeviceMemoryBase gpu_dst_ptr(dev_value_address, total * sizeof(V*));
compute_stream->ThenMemcpy(&gpu_dst_ptr, value_address, total * sizeof(V*));

int block_dim = 128;
TF_CHECK_OK(GpuLaunchKernel(
embedding::BatchUnpack<V>, (total + block_dim - 1) / block_dim * value_len,
block_dim, 0, gpu_device.stream(), dev_value_address, memcpy_buffer_gpu,
value_len, total));

auto do_insert = [this, copyback_keys, gpu_value_ptrs, value_len](
int64 start, int64 limit) {
for (int64 i = start; i < limit; i++)
storage_->Insert(copyback_keys[i], gpu_value_ptrs[i]);
};
Shard(worker_threads->num_threads, worker_threads->workers,
copyback_keys.size(), 100000, do_insert);
if (output_value_ptrs != nullptr) {
auto it = copyback_cursor.cbegin();
for (int64 i = 0; it != copyback_cursor.cend(); ++it, ++i) {
int64 cursor = *it & 0x0fffffffffffffff;
output_value_ptrs[cursor] = (int64)gpu_value_ptrs[i];
}
}
SyncWithEventMgr(compute_stream, event_mgr);

alloc_->DeallocateRaw(dev_value_address);
alloc_->DeallocateRaw(memcpy_buffer_gpu);
cpu_allocator()->DeallocateRaw(value_address);
delete[] gpu_value_ptrs;
}
}
#define REGISTER_KERNELS(ktype, vtype) \
template void EmbeddingVar<ktype, vtype>::CopyEmbeddingsFromCPUToGPU( \
const ktype*, const std::list<int64>&, vtype**, se::Stream*, EventMgr*, \
const Eigen::GpuDevice&, const DeviceBase::CpuWorkerThreads*, int64*);
#define REGISTER_KERNELS_ALL(type) \
REGISTER_KERNELS(int32, type); \
REGISTER_KERNELS(int64, type)
#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU)
#undef REGISTER_KERNELS_CPU

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS
} // namespace tensorflow
13 changes: 7 additions & 6 deletions tensorflow/core/framework/embedding/embedding_var_ckpt_data.cc
@@ -21,24 +21,25 @@ namespace tensorflow {
namespace embedding {
template<class K, class V>
void EmbeddingVarCkptData<K, V>::Emplace(
K key, ValuePtr<V>* value_ptr,
K key, void* value_ptr,
const EmbeddingConfig& emb_config,
V* default_value, int64 value_offset,
V* default_value,
FeatureDescriptor<V>* feat_desc,
bool is_save_freq,
bool is_save_version,
bool save_unfiltered_features) {
if((int64)value_ptr == ValuePtrStatus::IS_DELETED)
return;

bool is_in_dram = ((int64)value_ptr >> 49 == 0);
bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);
bool is_admit = feat_desc->IsAdmit(value_ptr);

if (is_admit) {
key_vec_.emplace_back(key);

if (!is_in_dram) {
value_ptr_vec_.emplace_back((V*)ValuePtrStatus::NOT_IN_DRAM);
value_ptr = (void*)((int64)value_ptr & 0x1ffffffffffff);
value_ptr = (void*)((int64)value_ptr & ((1L << kDramFlagOffset) - 1));
} else if (feat_desc->GetEmbedding(value_ptr, 0) == nullptr) {
value_ptr_vec_.emplace_back(default_value);
} else {
@@ -71,8 +72,8 @@ void EmbeddingVarCkptData<K, V>::Emplace(
}
#define REGISTER_KERNELS(ktype, vtype) \
template void EmbeddingVarCkptData<ktype, vtype>::Emplace( \
ktype, ValuePtr<vtype>*, const EmbeddingConfig&, \
vtype*, int64, bool, bool, bool);
ktype, void*, const EmbeddingConfig&, \
vtype*, FeatureDescriptor<vtype>*, bool, bool, bool);
#define REGISTER_KERNELS_ALL_INDEX(type) \
REGISTER_KERNELS(int32, type) \
REGISTER_KERNELS(int64, type)
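The Emplace hunk above reads and clears a flag packed into the upper bits of the value pointer (shift by kDramFlagOffset, mask (1L << kDramFlagOffset) - 1). Below is a minimal, self-contained sketch of that pointer-tagging idea, assuming the not-in-DRAM tag is bit 49 and that ordinary user-space addresses stay below it.

```cpp
// Bits at or above kDramFlagOffset mark a value pointer that is not resident
// in DRAM; masking them off recovers the original address.
#include <cassert>
#include <cstdint>
#include <iostream>

constexpr int kDramFlagOffset = 49;

int64_t TagNotInDram(int64_t ptr_bits) { return ptr_bits | (1L << kDramFlagOffset); }
bool IsInDram(int64_t ptr_bits) { return (ptr_bits >> kDramFlagOffset) == 0; }
int64_t ClearTag(int64_t ptr_bits) { return ptr_bits & ((1L << kDramFlagOffset) - 1); }

int main() {
  float value = 3.0f;
  // On common 64-bit platforms user-space addresses fit below bit 49.
  int64_t bits = reinterpret_cast<int64_t>(&value);
  assert(IsInDram(bits));

  int64_t tagged = TagNotInDram(bits);  // mark as not-in-DRAM
  assert(!IsInDram(tagged));
  assert(ClearTag(tagged) == bits);     // original pointer is recoverable

  std::cout << *reinterpret_cast<float*>(ClearTag(tagged)) << std::endl;
  return 0;
}
```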
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var_ckpt_data.h
@@ -19,6 +19,10 @@ limitations under the License.
#include "tensorflow/core/framework/embedding/embedding_var_dump_iterator.h"
namespace tensorflow {
class BundleWriter;
namespace {
const int kSavedPartitionNum = 1000;
const int kDramFlagOffset = 49;
}

namespace embedding {
template<class K, class V>
@@ -57,7 +57,7 @@ class EV2dVectorDataDumpIterator: public DumpIterator<T> {
value_len_(value_len),
col_idx_(0) {
if (!valueptr_list.empty()) {
if ((int64)*curr_iter_ == ValuePosition::NOT_IN_DRAM) {
if ((int64)*curr_iter_ == ValuePtrStatus::NOT_IN_DRAM) {
curr_ptr_ = val_iter_->Next();
} else {
curr_ptr_ = *curr_iter_;
@@ -75,7 +75,7 @@ class EV2dVectorDataDumpIterator: public DumpIterator<T> {
curr_iter_++;
col_idx_ = 0;
if (curr_iter_ != end_iter_) {
if ((int64)*curr_iter_ == ValuePosition::NOT_IN_DRAM) {
if ((int64)*curr_iter_ == ValuePtrStatus::NOT_IN_DRAM) {
curr_ptr_ = val_iter_->Next();
} else {
curr_ptr_ = *curr_iter_;
@@ -18,7 +18,6 @@ limitations under the License.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
