diff --git a/cpp/src/wholememory/embedding_optimizer.cpp b/cpp/src/wholememory/embedding_optimizer.cpp
index e96585289..1fa761014 100644
--- a/cpp/src/wholememory/embedding_optimizer.cpp
+++ b/cpp/src/wholememory/embedding_optimizer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -80,7 +80,6 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
   WHOLEMEMORY_CHECK_NOTHROW(optimizer_state != nullptr);
   WHOLEMEMORY_CHECK_NOTHROW(state_names_.size() == optimizer_state->cachable_states.size() +
                             optimizer_state->uncachable_states.size() + 1);
-  WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
   for (size_t i = 0; i < optimizer_state->cachable_states.size(); i++) {
     if (strcmp(state_name, optimizer_state->cachable_states[i].name.c_str()) == 0) {
       WHOLEMEMORY_CHECK_NOTHROW(strcmp(state_name, state_names_[i]) == 0);
@@ -94,6 +93,7 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
       return optimizer_state->uncachable_states[i].global_raw_sub_tensor;
     }
   }
+  WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
   return nullptr;
 }
 
diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
index feffa9162..1d49065e3 100644
--- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
+++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
@@ -828,9 +828,10 @@ cdef class PyWholeMemoryEmbedding:
         result = []
         cdef const char * const * state_names
         state_names = wholememory_embedding_get_optimizer_state_names(self.wm_embedding)
-        while state_names[i] != NULL:
-            result.append( PyUnicode_FromString(state_names[i]))
-            i += 1
+        if state_names != NULL:
+            while state_names[i] != NULL:
+                result.append( PyUnicode_FromString(state_names[i]))
+                i += 1
         return result
 
     def get_optimizer_state(self,
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
index 8abc92be9..67f02df77 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -267,7 +267,7 @@ def __init__(
         super().__init__()
         self.wmb_embedding = wmb_embedding
         self.embedding_tensor = None
-        self.optimizer_states = None
+        self.optimizer_states = dict()
         self.wmb_optimizer = wmb_optimizer
         self.wmb_cache_policy = wmb_cache_policy
 