diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 885dddd8e..7fc6ad174 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -80,13 +80,23 @@ enum wholememory_distributed_backend_t { WHOLEMEMORY_DB_NCCL, WHOLEMEMORY_DB_NVSHMEM, }; + +enum LogLevel { + LEVEL_FATAL = 0, /*!< Fatal */ + LEVEL_ERROR, /*!< Error */ + LEVEL_WARN, /*!< Warn */ + LEVEL_INFO, /*!< Info */ + LEVEL_DEBUG, /*!< Debug */ + LEVEL_TRACE /*!< Trace */ +}; + /** * Initialize WholeMemory library * @param flags : reserved should be 0 - * @param wm_log_level : wholememory log level, the default level is "info" + * @param log_level : wholememory log level, the default level is "info" * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level = 3); +wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel log_level = LEVEL_INFO); /** * Finalize WholeMemory library diff --git a/cpp/src/logger.cpp b/cpp/src/logger.cpp index 21bd618f2..bc0337cfa 100644 --- a/cpp/src/logger.cpp +++ b/cpp/src/logger.cpp @@ -21,14 +21,14 @@ namespace wholememory { -int& get_log_level() +LogLevel& get_log_level() { - static int log_level = LEVEL_INFO; + static LogLevel log_level = LEVEL_INFO; return log_level; } -void set_log_level(int lev) { get_log_level() = lev; } +void set_log_level(LogLevel lev) { get_log_level() = lev; } -bool will_log_for(int lev) { return lev <= get_log_level(); } +bool will_log_for(LogLevel lev) { return lev <= get_log_level(); } } // namespace wholememory diff --git a/cpp/src/logger.hpp b/cpp/src/logger.hpp index 0d10f0638..5fe9a6689 100644 --- a/cpp/src/logger.hpp +++ b/cpp/src/logger.hpp @@ -24,21 +24,15 @@ #include #include "error.hpp" +#include <wholememory/wholememory.h> namespace wholememory { -static constexpr int LEVEL_FATAL = 0; -static constexpr int LEVEL_ERROR = 10; -static constexpr int LEVEL_WARN = 100; -static constexpr int LEVEL_INFO = 1000; 
-static constexpr int LEVEL_DEBUG = 10000; -static constexpr int LEVEL_TRACE = 100000; +LogLevel& get_log_level(); -int& get_log_level(); +void set_log_level(LogLevel lev); -void set_log_level(int lev); - -bool will_log_for(int lev); +bool will_log_for(LogLevel lev); /** * @defgroup CStringFormat Expand a C-style format string @@ -86,10 +80,10 @@ inline std::string format(const char* fmt, ...) throw wholememory::logic_error(fatal_msg); \ } while (0) -#define WHOLEMEMORY_ERROR(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_ERROR, fmt, ##__VA_ARGS__) -#define WHOLEMEMORY_WARN(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_WARN, fmt, ##__VA_ARGS__) -#define WHOLEMEMORY_INFO(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_INFO, fmt, ##__VA_ARGS__) -#define WHOLEMEMORY_DEBUG(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_DEBUG, fmt, ##__VA_ARGS__) -#define WHOLEMEMORY_TRACE(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_TRACE, fmt, ##__VA_ARGS__) +#define WHOLEMEMORY_ERROR(fmt, ...) WHOLEMEMORY_LOG(LEVEL_ERROR, fmt, ##__VA_ARGS__) +#define WHOLEMEMORY_WARN(fmt, ...) WHOLEMEMORY_LOG(LEVEL_WARN, fmt, ##__VA_ARGS__) +#define WHOLEMEMORY_INFO(fmt, ...) WHOLEMEMORY_LOG(LEVEL_INFO, fmt, ##__VA_ARGS__) +#define WHOLEMEMORY_DEBUG(fmt, ...) WHOLEMEMORY_LOG(LEVEL_DEBUG, fmt, ##__VA_ARGS__) +#define WHOLEMEMORY_TRACE(fmt, ...) 
WHOLEMEMORY_LOG(LEVEL_TRACE, fmt, ##__VA_ARGS__) } // namespace wholememory diff --git a/cpp/src/wholememory/initialize.cpp b/cpp/src/wholememory/initialize.cpp index b7d1e54ac..f614ad38f 100644 --- a/cpp/src/wholememory/initialize.cpp +++ b/cpp/src/wholememory/initialize.cpp @@ -17,7 +17,6 @@ #include #include -#include #include #include "communicator.hpp" @@ -33,7 +32,7 @@ static bool is_wm_init = false; static const std::string RAFT_NAME = "wholememory"; static cudaDeviceProp* device_props = nullptr; -wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept +wholememory_error_code_t init(unsigned int flags, LogLevel log_level) noexcept { try { std::unique_lock lock(mu); @@ -51,7 +50,7 @@ wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noe WM_CUDA_CHECK(cudaGetDeviceProperties(device_props + i, i)); } is_wm_init = true; - wholememory::set_log_level(std::pow(10, wm_log_level)); + wholememory::set_log_level(log_level); return WHOLEMEMORY_SUCCESS; } catch (raft::logic_error& logic_error) { WHOLEMEMORY_ERROR("init failed, logic_error=%s", logic_error.what()); diff --git a/cpp/src/wholememory/initialize.hpp b/cpp/src/wholememory/initialize.hpp index 77870f989..6afb1cbe8 100644 --- a/cpp/src/wholememory/initialize.hpp +++ b/cpp/src/wholememory/initialize.hpp @@ -21,7 +21,7 @@ namespace wholememory { -wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept; +wholememory_error_code_t init(unsigned int flags, LogLevel log_level) noexcept; wholememory_error_code_t finalize() noexcept; diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 2f5f33a36..478833117 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -25,9 +25,9 @@ extern "C" { #endif -wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level) +wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel 
log_level) { - return wholememory::init(flags, wm_log_level); + return wholememory::init(flags, log_level); } wholememory_error_code_t wholememory_finalize() { return wholememory::finalize(); } diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 77d86ffdb..0499007bf 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -71,7 +71,16 @@ cdef extern from "wholememory/wholememory.h": WHOLEMEMORY_DB_NONE "WHOLEMEMORY_DB_NONE" WHOLEMEMORY_DB_NCCL "WHOLEMEMORY_DB_NCCL" WHOLEMEMORY_DB_NVSHMEM "WHOLEMEMORY_DB_NVSHMEM" - cdef wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level) + + ctypedef enum LogLevel: + LEVEL_FATAL "LEVEL_FATAL" + LEVEL_ERROR "LEVEL_ERROR" + LEVEL_WARN "LEVEL_WARN" + LEVEL_INFO "LEVEL_INFO" + LEVEL_DEBUG "LEVEL_DEBUG" + LEVEL_TRACE "LEVEL_TRACE" + + cdef wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel log_level) cdef wholememory_error_code_t wholememory_finalize() @@ -204,6 +213,14 @@ cpdef enum WholeMemoryDistributedBackend: DbNCCL = WHOLEMEMORY_DB_NCCL DbNVSHMEM = WHOLEMEMORY_DB_NVSHMEM +cpdef enum WholeMemoryLogLevel: + LevFatal = LEVEL_FATAL + LevError = LEVEL_ERROR + LevWarn = LEVEL_WARN + LevInfo = LEVEL_INFO + LevDebug = LEVEL_DEBUG + LevTrace = LEVEL_TRACE + cdef check_wholememory_error_code(wholememory_error_code_t err): cdef WholeMemoryErrorCode err_code = int(err) if err_code == Success: @@ -986,8 +1003,8 @@ cdef class PyWholeMemoryUniqueID: def __dlpack_device__(self): return (kDLCPU, 0) -def init(unsigned int flags, unsigned int wm_log_level = 3): - check_wholememory_error_code(wholememory_init(flags, wm_log_level)) +def init(unsigned int flags, LogLevel log_level = LEVEL_INFO): + check_wholememory_error_code(wholememory_init(flags, log_level)) def finalize(): 
check_wholememory_error_code(wholememory_finalize()) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 14955305b..3bf480ba1 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -9,7 +9,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License.ß +# limitations under the License. from argparse import ArgumentParser diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 94ee74261..535594f6b 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -16,15 +16,15 @@ import torch.utils.dlpack import pylibwholegraph.binding.wholememory_binding as wmb from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators +from .utils import str_to_wmb_wholememory_log_level def init(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"): - log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5} - wmb.init(0, log_level_dic[wm_log_level]) + wmb.init(0, str_to_wmb_wholememory_log_level(wm_log_level)) set_world_info(world_rank, world_size, local_rank, local_size) -def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level): +def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"): r"""Init WholeGraph environment for PyTorch. 
:param world_rank: world rank of current process :param world_size: world size of all processes @@ -45,8 +45,7 @@ def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size print("[WARNING] MASTER_PORT not set, resetting to 12335") os.environ["MASTER_PORT"] = "12335" - log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5} - wmb.init(0, log_level_dic[wm_log_level]) + wmb.init(0, str_to_wmb_wholememory_log_level(wm_log_level)) torch.set_num_threads(1) torch.cuda.set_device(local_rank) torch.distributed.init_process_group(backend="nccl", init_method="env://") diff --git a/python/pylibwholegraph/pylibwholegraph/torch/utils.py b/python/pylibwholegraph/pylibwholegraph/torch/utils.py index dab74a261..c03c2f061 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/utils.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/utils.py @@ -99,6 +99,24 @@ def str_to_wmb_wholememory_memory_type(str_wmb_type: str): ) +def str_to_wmb_wholememory_log_level(str_log_level: str): + if str_log_level == "error": + return wmb.WholeMemoryLogLevel.LevError + elif str_log_level == "warn": + return wmb.WholeMemoryLogLevel.LevWarn + elif str_log_level == "info": + return wmb.WholeMemoryLogLevel.LevInfo + elif str_log_level == "debug": + return wmb.WholeMemoryLogLevel.LevDebug + elif str_log_level == "trace": + return wmb.WholeMemoryLogLevel.LevTrace + else: + raise ValueError( + "WholeMemory log level %s not supported, should be (error, warn, info, debug, trace)" % (str_log_level,) + ) + + def str_to_wmb_wholememory_location(str_wmb_location: str): if str_wmb_location == "cuda": return wmb.WholeMemoryMemoryLocation.MlDevice