decouple embedding creation from optimizer (#186)
This PR refactors the embedding creation interface, decoupling it from the optimizer. Users can now designate the embeddings to be optimized when initializing the optimizer.
cpp:
```cpp
wholememory_create_embedding(&wm_embedding, ...);
wholememory_create_embedding_optimizer(&optimizer, ...);
wholememory_embedding_set_optimizer(wm_embedding, optimizer);
```
python:
```python
wm_embedding = wgth.create_embedding(...)
wm_optimizer = wgth.create_wholememory_optimizer(wm_embedding, "adam", {})
```
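
A slightly fuller sketch of the new C++ flow is shown below for context. It is illustrative only: the `/* ... */` arguments stand for the unchanged creation parameters (embedding tensor description, communicator, memory type/location, cache policy, etc.), and the error handling is an assumption rather than code from this PR. Inference-only embeddings, which previously passed a `nullptr` optimizer, now simply skip the `wholememory_embedding_set_optimizer` call.
```cpp
// Illustrative sketch of the decoupled flow; the "/* ... */" arguments are the
// existing creation parameters and must be supplied by the caller.
wholememory_embedding_t wm_embedding         = nullptr;
wholememory_embedding_optimizer_t optimizer  = nullptr;

// 1. Create the embedding; this call no longer takes an optimizer argument.
if (wholememory_create_embedding(&wm_embedding, /* ... */) != WHOLEMEMORY_SUCCESS) {
  // handle creation failure
}

// 2. Create the optimizer independently of any embedding.
if (wholememory_create_embedding_optimizer(&optimizer, /* ... */) != WHOLEMEMORY_SUCCESS) {
  // handle creation failure
}

// 3. Attach the optimizer only to embeddings that will be trained.
if (wholememory_embedding_set_optimizer(wm_embedding, optimizer) != WHOLEMEMORY_SUCCESS) {
  // handle failure (see the checks added in embedding.cpp below)
}
```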

Authors:
  - https://github.com/zhuofan1123

Approvers:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

URL: #186
zhuofan1123 authored Jun 13, 2024
1 parent 996f8f7 commit 8d4cd9b
Showing 8 changed files with 114 additions and 89 deletions.
11 changes: 9 additions & 2 deletions cpp/include/wholememory/embedding.h
@@ -128,7 +128,6 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy(
* @param comm : WholeMemory Communicator
* @param memory_type : Memory Type of the underlying WholeMemory
* @param memory_location : Memory Location of the underlying WholeMemory
* @param optimizer : Optimizer to use for training, if don't train embedding, use nullptr
* @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr
* @param user_defined_sms : User-defined sms number for raw embedding gather/scatter
* @param round_robin_size : continuous embedding size in each rank under round-robin shard mode
@@ -140,7 +139,6 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_comm_t comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_optimizer_t optimizer,
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms = -1,
int round_robin_size = 0);
@@ -161,6 +159,15 @@ wholememory_error_code_t wholememory_destroy_embedding(
wholememory_tensor_t wholememory_embedding_get_embedding_tensor(
wholememory_embedding_t wholememory_embedding);

/**
* Set Optimizer for WholeMemory Embedding
* @param wholememory_embedding : WholeMemory Embedding
* @param optimizer : Optimizer to be set
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_embedding_set_optimizer(
wholememory_embedding_t wholememory_embedding, wholememory_embedding_optimizer_t optimizer);

/**
* Gather from WholeMemory Embedding
* @param wholememory_embedding : WholeMemory Embedding
67 changes: 45 additions & 22 deletions cpp/src/wholememory/embedding.cpp
@@ -49,23 +49,53 @@ static int64_t align_embedding_dim(int64_t embedding_dim, size_t element_size)
return embedding_stride;
}

wholememory_error_code_t embedding_base::set_optimizer(wholememory_embedding_optimizer_t opt)
{
try {
if (optimizer != nullptr) {
WHOLEMEMORY_ERROR("optimizer can only be set once.");
return WHOLEMEMORY_NOT_SUPPORTED;
}
optimizer = opt;
if (optimizer != nullptr) {
if (embedding_dtype_ != WHOLEMEMORY_DT_FLOAT) {
WHOLEMEMORY_ERROR("Only float embedding supports training.");
return WHOLEMEMORY_NOT_IMPLEMENTED;
}
if (cache_policy != nullptr) {
WHOLEMEMORY_CHECK_NOTHROW(cache_policy->access_type == WHOLEMEMORY_AT_READWRITE);
if (cache_policy->cache_comm != raw_embedding_comm_) {
WHOLEMEMORY_ERROR("optimizer not supported for local cached global readonly embedding.");
return WHOLEMEMORY_INVALID_INPUT;
}
}
optimizer_impl_base_ = static_cast<embedding_optimizer_impl_base*>(optimizer);
WHOLEMEMORY_RETURN_ON_FAIL(create_optimizer_states());
WHOLEMEMORY_RETURN_ON_FAIL(init_optimizer_states());
}
} catch (std::bad_alloc& sba) {
WHOLEMEMORY_ERROR("bad_alloc");
return WHOLEMEMORY_OUT_OF_MEMORY;
} catch (...) {
WHOLEMEMORY_ERROR("Unknown error");
return WHOLEMEMORY_UNKNOW_ERROR;
}

return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t embedding_base::allocate(
wholememory_matrix_description_t* embedding_description,
wholememory_comm_t comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_cache_policy_t policy,
wholememory_embedding_optimizer_t opt) noexcept
wholememory_embedding_cache_policy_t policy) noexcept
{
cache_policy = policy;
optimizer = opt;
raw_embedding_comm_ = comm;
embedding_dtype_ = embedding_description->dtype;
wholememory_tensor_description_t padded_embedding_tensor_description;
try {
if (optimizer != nullptr && embedding_description->dtype != WHOLEMEMORY_DT_FLOAT) {
WHOLEMEMORY_ERROR("Only float embedding supports training.");
return WHOLEMEMORY_NOT_IMPLEMENTED;
}
if (cache_policy != nullptr) {
WHOLEMEMORY_CHECK_NOTHROW(cache_policy->cache_comm != nullptr);
if (cache_policy->cache_comm != comm) {
@@ -99,14 +129,6 @@ wholememory_error_code_t embedding_base::allocate(
WHOLEMEMORY_RETURN_ON_FAIL(
wholememory_tensor_get_subtensor(allocated_embedding, &starts[0], &ends[0], &user_embedding));
if (cache_ptr_ != nullptr) { WHOLEMEMORY_RETURN_ON_FAIL(cache_ptr_->allocate(user_embedding)); }
if (optimizer != nullptr) {
if (cache_policy != nullptr) {
WHOLEMEMORY_CHECK_NOTHROW(cache_policy->access_type == WHOLEMEMORY_AT_READWRITE);
}
optimizer_impl_base_ = static_cast<embedding_optimizer_impl_base*>(optimizer);
WHOLEMEMORY_RETURN_ON_FAIL(create_optimizer_states());
WHOLEMEMORY_RETURN_ON_FAIL(init_optimizer_states());
}
} catch (std::bad_alloc& sba) {
WHOLEMEMORY_ERROR("bad_alloc");
return WHOLEMEMORY_OUT_OF_MEMORY;
@@ -341,7 +363,6 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept
raw_embedding_comm_,
memory_type,
memory_location,
nullptr,
cache_policy));

optimizer_state_->global_cachable_raw_user_tensor =
@@ -881,7 +902,6 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_comm_t comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_optimizer_t optimizer,
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms,
int round_robin_size)
@@ -939,10 +959,6 @@ wholememory_error_code_t wholememory_create_embedding(
"Only ReadOnly access type supported for local cached global readonly embedding.");
return WHOLEMEMORY_INVALID_INPUT;
}
if (optimizer != nullptr) {
WHOLEMEMORY_ERROR("optimizer not supported for local cached global readonly embedding.");
return WHOLEMEMORY_INVALID_INPUT;
}
embedding_impl_ptr = new wholememory::local_cached_global_readonly_embedding();
}
} else {
@@ -953,11 +969,18 @@ wholememory_error_code_t wholememory_create_embedding(
&embedding_matrix_description, embedding_world_size, round_robin_size);
embedding_impl_ptr->set_gather_sms(user_defined_sms);
WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate(
&embedding_matrix_description, comm, memory_type, memory_location, cache_policy, optimizer));
&embedding_matrix_description, comm, memory_type, memory_location, cache_policy));
*wholememory_embedding = static_cast<wholememory_embedding_t>(embedding_impl_ptr);
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t wholememory_embedding_set_optimizer(
wholememory_embedding_t wholememory_embedding, wholememory_embedding_optimizer_t optimizer)
{
auto* embedding_impl_ptr = static_cast<wholememory::embedding_base*>(wholememory_embedding);
WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->set_optimizer(optimizer));
return WHOLEMEMORY_SUCCESS;
}
wholememory_error_code_t wholememory_destroy_embedding(
wholememory_embedding_t wholememory_embedding)
{
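
As a reading aid for the new `embedding_base::set_optimizer` above, the hypothetical caller-side snippet below maps its failure modes to the error codes it returns; it is not part of this PR.
```cpp
// Hypothetical handling of the error codes reported by
// wholememory_embedding_set_optimizer, per the checks in embedding.cpp:
switch (wholememory_embedding_set_optimizer(wm_embedding, optimizer)) {
  case WHOLEMEMORY_SUCCESS:
    break;  // optimizer states were created and initialized
  case WHOLEMEMORY_NOT_SUPPORTED:
    break;  // an optimizer was already set; it can only be set once
  case WHOLEMEMORY_NOT_IMPLEMENTED:
    break;  // embedding dtype is not float; only float embeddings support training
  case WHOLEMEMORY_INVALID_INPUT:
    break;  // local-cached, global read-only embedding; optimizer not supported
  default:
    break;  // e.g. WHOLEMEMORY_OUT_OF_MEMORY, WHOLEMEMORY_UNKNOW_ERROR
}
```
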
6 changes: 4 additions & 2 deletions cpp/src/wholememory/embedding.hpp
@@ -45,8 +45,7 @@ class embedding_base : public wholememory_embedding_ {
wholememory_comm_t comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_cache_policy_t policy,
wholememory_embedding_optimizer_t opt) noexcept;
wholememory_embedding_cache_policy_t policy) noexcept;
void deallocate() noexcept;
virtual wholememory_error_code_t gather(wholememory_tensor_t indices,
wholememory_tensor_t output,
@@ -61,6 +60,8 @@ class embedding_base : public wholememory_embedding_ {
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);

wholememory_error_code_t set_optimizer(wholememory_embedding_optimizer_t opt);

[[nodiscard]] const char* const* get_optimizer_state_names() const noexcept
{
if (optimizer_impl_base_ != nullptr) {
@@ -104,6 +105,7 @@ class embedding_base : public wholememory_embedding_ {

int gather_sms_;
int round_robin_size_;
wholememory_dtype_t embedding_dtype_ = WHOLEMEMORY_DT_UNKNOWN;
wholememory_comm_t raw_embedding_comm_ = nullptr;
wholememory::embedding_cache_base* cache_ptr_ = nullptr;
wholememory::embedding_optimizer_impl_base* optimizer_impl_base_ = nullptr;
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -592,10 +592,9 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT
wm_comm,
params.memory_type,
params.memory_location,
optimizer,
cache_policy),
WHOLEMEMORY_SUCCESS);

EXPECT_EQ(wholememory_embedding_set_optimizer(wm_embedding, optimizer), WHOLEMEMORY_SUCCESS);
wholememory_tensor_t embedding_tensor =
wholememory_embedding_get_embedding_tensor(wm_embedding);
wholememory_tensor_t local_embed_tensor;
3 changes: 1 addition & 2 deletions cpp/tests/wholememory_ops/wholememory_embedding_tests.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -244,7 +244,6 @@ TEST_P(WholeMemoryEmbeddingParameterTests, EmbeddingGatherTest)
wm_comm,
params.memory_type,
params.memory_location,
nullptr,
cache_policy),
WHOLEMEMORY_SUCCESS);

13 changes: 3 additions & 10 deletions python/pylibwholegraph/examples/node_classfication.py
@@ -201,23 +201,16 @@ def main_func():
args.cache_ratio,
)

wm_optimizer = (
None
if args.train_embedding is False
else wgth.create_wholememory_optimizer("adam", {})
)

wm_optimizer = None
embedding_dtype = torch.float32 if not args.fp16_mbedding else torch.float16

if wm_optimizer is None:
if args.train_embedding is False:
node_feat_wm_embedding = wgth.create_embedding_from_filelist(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
os.path.join(args.root_dir, "node_feat.bin"),
embedding_dtype,
args.feat_dim,
optimizer=wm_optimizer,
cache_policy=cache_policy,
round_robin_size=args.round_robin_size,
)
@@ -228,11 +221,11 @@ def main_func():
embedding_wholememory_location,
embedding_dtype,
[graph_structure.node_count, args.feat_dim],
optimizer=wm_optimizer,
cache_policy=cache_policy,
random_init=True,
round_robin_size=args.round_robin_size,
)
wm_optimizer = wgth.create_wholememory_optimizer(node_feat_wm_embedding, "adam", {})
wgth.set_framework(args.framework)
model = wgth.HomoGNNModel(graph_structure, node_feat_wm_embedding, args)
model.cuda()
@@ -624,14 +624,17 @@ cdef extern from "wholememory/embedding.h":
wholememory_comm_t comm,
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_optimizer_t optimizer,
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms,
int round_robin_size)

cdef wholememory_error_code_t wholememory_destroy_embedding(
wholememory_embedding_t wholememory_embedding)

cdef wholememory_error_code_t wholememory_embedding_set_optimizer(
wholememory_embedding_t wholememory_embedding,
wholememory_embedding_optimizer_t optimizer);

cdef wholememory_error_code_t wholememory_embedding_gather(wholememory_embedding_t wholememory_embedding,
wholememory_tensor_t indices,
wholememory_tensor_t output,
@@ -701,6 +704,10 @@ cdef class WholeMemoryOptimizer:
check_wholememory_error_code(
wholememory_optimizer_set_parameter(self.wm_optimizer, key_bytes, &param_value))

def add_embedding(self,
PyWholeMemoryEmbedding embedding):
wholememory_embedding_set_optimizer(embedding.wm_embedding, self.wm_optimizer)

def destroy_optimizer(self):
if self.wm_optimizer == NULL:
return
@@ -789,7 +796,6 @@ cdef class PyWholeMemoryEmbedding:
PyWholeMemoryComm comm,
WholeMemoryMemoryType memory_type,
WholeMemoryMemoryLocation memory_location,
WholeMemoryOptimizer optimizer,
WholeMemoryCachePolicy cache_policy,
int user_defined_sms,
int round_robin_size):
@@ -800,7 +806,6 @@ cdef class PyWholeMemoryEmbedding:
comm.comm_id,
self.memory_type,
self.memory_location,
optimizer.wm_optimizer,
cache_policy.cache_policy,
user_defined_sms,
round_robin_size))
@@ -848,7 +853,6 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc,
PyWholeMemoryComm comm,
WholeMemoryMemoryType memory_type,
WholeMemoryMemoryLocation memory_location,
WholeMemoryOptimizer optimizer,
WholeMemoryCachePolicy cache_policy,
int user_defined_sms,
int round_robin_size):
@@ -857,7 +861,6 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc,
comm,
memory_type,
memory_location,
optimizer,
cache_policy,
user_defined_sms,
round_robin_size)