Skip to content

Commit

Permalink
fixed bugs (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 authored Jun 13, 2024
1 parent c92bba3 commit 996f8f7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions cpp/src/wholememory/embedding_optimizer.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -80,7 +80,6 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
WHOLEMEMORY_CHECK_NOTHROW(optimizer_state != nullptr);
WHOLEMEMORY_CHECK_NOTHROW(state_names_.size() == optimizer_state->cachable_states.size() +
optimizer_state->uncachable_states.size() + 1);
WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
for (size_t i = 0; i < optimizer_state->cachable_states.size(); i++) {
if (strcmp(state_name, optimizer_state->cachable_states[i].name.c_str()) == 0) {
WHOLEMEMORY_CHECK_NOTHROW(strcmp(state_name, state_names_[i]) == 0);
Expand All @@ -94,6 +93,7 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
return optimizer_state->uncachable_states[i].global_raw_sub_tensor;
}
}
WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
return nullptr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,10 @@ cdef class PyWholeMemoryEmbedding:
result = []
cdef const char * const * state_names
state_names = wholememory_embedding_get_optimizer_state_names(self.wm_embedding)
while state_names[i] != NULL:
result.append(<object> PyUnicode_FromString(state_names[i]))
i += 1
if state_names != NULL:
while state_names[i] != NULL:
result.append(<object> PyUnicode_FromString(state_names[i]))
i += 1
return result

def get_optimizer_state(self,
Expand Down
2 changes: 1 addition & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(
super().__init__()
self.wmb_embedding = wmb_embedding
self.embedding_tensor = None
self.optimizer_states = None
self.optimizer_states = dict()

self.wmb_optimizer = wmb_optimizer
self.wmb_cache_policy = wmb_cache_policy
Expand Down

0 comments on commit 996f8f7

Please sign in to comment.