Commit ad81cf5
Merge branch 'branch-24.04' into remove_unnecessary_sync
linhu-nv authored Mar 6, 2024
2 parents 30978be + daafc76 commit ad81cf5
Showing 13 changed files with 77 additions and 40 deletions.
14 changes: 12 additions & 2 deletions cpp/include/wholememory/wholememory.h
@@ -80,13 +80,23 @@ enum wholememory_distributed_backend_t {
WHOLEMEMORY_DB_NCCL,
WHOLEMEMORY_DB_NVSHMEM,
};

enum LogLevel {
LEVEL_FATAL = 0, /*!< Fatal */
LEVEL_ERROR, /*!< Error */
LEVEL_WARN, /*!< Warn */
LEVEL_INFO, /*!< Info */
LEVEL_DEBUG, /*!< Debug*/
LEVEL_TRACE /*!< Trace */
};

/**
* Initialize WholeMemory library
* @param flags : reserved should be 0
* @param wm_log_level : wholememory log level, the default level is "info"
* @param log_level : wholememory log level, the default level is "info"
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level = 3);
wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel log_level = LEVEL_INFO);

/**
* Finalize WholeMemory library
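The public entry point now takes the LogLevel enum instead of a raw unsigned int. A minimal sketch of a caller against the new header, using only the declarations shown above (error handling kept to the bare minimum):

```cpp
#include <wholememory/wholememory.h>

int main()
{
  // The second argument now takes the LogLevel enum directly; omitting it keeps
  // the default of LEVEL_INFO declared in the header above.
  if (wholememory_init(0, LEVEL_DEBUG) != WHOLEMEMORY_SUCCESS) { return 1; }
  // ... use WholeMemory ...
  wholememory_finalize();
  return 0;
}
```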
8 changes: 4 additions & 4 deletions cpp/src/logger.cpp
@@ -21,14 +21,14 @@

namespace wholememory {

int& get_log_level()
LogLevel& get_log_level()
{
static int log_level = LEVEL_INFO;
static LogLevel log_level = LEVEL_INFO;
return log_level;
}

void set_log_level(int lev) { get_log_level() = lev; }
void set_log_level(LogLevel lev) { get_log_level() = lev; }

bool will_log_for(int lev) { return lev <= get_log_level(); }
bool will_log_for(LogLevel lev) { return lev <= get_log_level(); }

} // namespace wholememory
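The getters above keep the configured level in a function-local static and compare against it in will_log_for. A small sketch of the resulting behavior, assuming the "logger.hpp" include path (it is an internal header):

```cpp
#include <cassert>

#include "logger.hpp"  // internal header declaring set_log_level/will_log_for (path assumed)

int main()
{
  // With LEVEL_FATAL(0) < LEVEL_ERROR < LEVEL_WARN < LEVEL_INFO < LEVEL_DEBUG < LEVEL_TRACE,
  // a message is emitted only when its level is <= the configured level.
  wholememory::set_log_level(LEVEL_INFO);
  assert(wholememory::will_log_for(LEVEL_WARN));    // WARN (2) <= INFO (3): logged
  assert(!wholememory::will_log_for(LEVEL_DEBUG));  // DEBUG (4) > INFO (3): suppressed
  return 0;
}
```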
24 changes: 9 additions & 15 deletions cpp/src/logger.hpp
@@ -24,21 +24,15 @@
#include <raft/core/error.hpp>

#include "error.hpp"
#include <wholememory/wholememory.h>

namespace wholememory {

static constexpr int LEVEL_FATAL = 0;
static constexpr int LEVEL_ERROR = 10;
static constexpr int LEVEL_WARN = 100;
static constexpr int LEVEL_INFO = 1000;
static constexpr int LEVEL_DEBUG = 10000;
static constexpr int LEVEL_TRACE = 100000;
LogLevel& get_log_level();

int& get_log_level();
void set_log_level(LogLevel lev);

void set_log_level(int lev);

bool will_log_for(int lev);
bool will_log_for(LogLevel lev);

/**
* @defgroup CStringFormat Expand a C-style format string
@@ -86,10 +80,10 @@ inline std::string format(const char* fmt, ...)
throw wholememory::logic_error(fatal_msg); \
} while (0)

#define WHOLEMEMORY_ERROR(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_ERROR, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_WARN(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_WARN, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_INFO(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_INFO, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_DEBUG(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_DEBUG, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_TRACE(fmt, ...) WHOLEMEMORY_LOG(wholememory::LEVEL_TRACE, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_ERROR(fmt, ...) WHOLEMEMORY_LOG(LEVEL_ERROR, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_WARN(fmt, ...) WHOLEMEMORY_LOG(LEVEL_WARN, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_INFO(fmt, ...) WHOLEMEMORY_LOG(LEVEL_INFO, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_DEBUG(fmt, ...) WHOLEMEMORY_LOG(LEVEL_DEBUG, fmt, ##__VA_ARGS__)
#define WHOLEMEMORY_TRACE(fmt, ...) WHOLEMEMORY_LOG(LEVEL_TRACE, fmt, ##__VA_ARGS__)

} // namespace wholememory
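The logging macros now use the unqualified enumerators, which resolve to the enum in the public header. Usage stays printf-style, as in the init error path later in this diff; an illustrative call site (function name and include path are hypothetical):

```cpp
#include <cstddef>

#include "logger.hpp"  // internal header (path assumed)

void report_allocation(int rank, std::size_t bytes)
{
  // printf-style format plus arguments; whether a line is emitted depends on the
  // level configured via set_log_level().
  WHOLEMEMORY_INFO("rank %d allocated %zu bytes", rank, bytes);
  WHOLEMEMORY_DEBUG("debug detail, visible only at LEVEL_DEBUG or LEVEL_TRACE");
}
```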
5 changes: 2 additions & 3 deletions cpp/src/wholememory/initialize.cpp
@@ -17,7 +17,6 @@

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <nccl.h>

#include "communicator.hpp"
@@ -33,7 +32,7 @@ static bool is_wm_init = false;
static const std::string RAFT_NAME = "wholememory";
static cudaDeviceProp* device_props = nullptr;

wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept
wholememory_error_code_t init(unsigned int flags, LogLevel log_level) noexcept
{
try {
std::unique_lock<std::mutex> lock(mu);
@@ -51,7 +50,7 @@ wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept
WM_CUDA_CHECK(cudaGetDeviceProperties(device_props + i, i));
}
is_wm_init = true;
wholememory::set_log_level(std::pow(10, wm_log_level));
wholememory::set_log_level(log_level);
return WHOLEMEMORY_SUCCESS;
} catch (raft::logic_error& logic_error) {
WHOLEMEMORY_ERROR("init failed, logic_error=%s", logic_error.what());
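For context on the removed std::pow call: the old public value was a number in the 1–5 range that had to be expanded onto the power-of-ten constants deleted from logger.hpp above. A rough illustration of that old mapping (the new code stores the LogLevel enum value directly, so no conversion remains):

```cpp
#include <cmath>
#include <cstdio>

int main()
{
  // The removed line expanded the old public value (1..5, matching the dict that was
  // also removed from torch/initialize.py below) onto the old power-of-ten constants:
  // LEVEL_ERROR=10 ... LEVEL_TRACE=100000.
  const unsigned int old_default = 3;  // the old numeric default, i.e. "info"
  const long old_internal = std::lround(std::pow(10, old_default));
  std::printf("old: %u -> %ld\n", old_default, old_internal);  // old: 3 -> 1000
  return 0;
}
```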
2 changes: 1 addition & 1 deletion cpp/src/wholememory/initialize.hpp
@@ -21,7 +21,7 @@

namespace wholememory {

wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept;
wholememory_error_code_t init(unsigned int flags, LogLevel log_level) noexcept;

wholememory_error_code_t finalize() noexcept;

4 changes: 2 additions & 2 deletions cpp/src/wholememory/wholememory.cpp
@@ -25,9 +25,9 @@
extern "C" {
#endif

wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level)
wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel log_level)
{
return wholememory::init(flags, wm_log_level);
return wholememory::init(flags, log_level);
}

wholememory_error_code_t wholememory_finalize() { return wholememory::finalize(); }
@@ -89,7 +89,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func(
embedding_entry_count_per_rank,
p_env_fns,
stream,
scatter_sms;
scatter_sms);
} catch (const wholememory::cuda_error& wle) {
WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what());
return WHOLEMEMORY_LOGIC_ERROR;
@@ -71,7 +71,16 @@ cdef extern from "wholememory/wholememory.h":
WHOLEMEMORY_DB_NONE "WHOLEMEMORY_DB_NONE"
WHOLEMEMORY_DB_NCCL "WHOLEMEMORY_DB_NCCL"
WHOLEMEMORY_DB_NVSHMEM "WHOLEMEMORY_DB_NVSHMEM"
cdef wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level)

ctypedef enum LogLevel:
LEVEL_FATAL "LEVEL_FATAL"
LEVEL_ERROR "LEVEL_ERROR"
LEVEL_WARN "LEVEL_WARN"
LEVEL_INFO "LEVEL_INFO"
LEVEL_DEBUG "LEVEL_DEBUG"
LEVEL_TRACE "LEVEL_TRACE"

cdef wholememory_error_code_t wholememory_init(unsigned int flags, LogLevel log_level)

cdef wholememory_error_code_t wholememory_finalize()

@@ -204,6 +213,14 @@ cpdef enum WholeMemoryDistributedBackend:
DbNCCL = WHOLEMEMORY_DB_NCCL
DbNVSHMEM = WHOLEMEMORY_DB_NVSHMEM

cpdef enum WholeMemoryLogLevel:
LevFatal = LEVEL_FATAL
LevError = LEVEL_ERROR
LevWarn = LEVEL_WARN
LevInfo = LEVEL_INFO
LevDebug = LEVEL_DEBUG
LevTrace = LEVEL_TRACE

cdef check_wholememory_error_code(wholememory_error_code_t err):
cdef WholeMemoryErrorCode err_code = int(err)
if err_code == Success:
@@ -986,8 +1003,8 @@ cdef class PyWholeMemoryUniqueID:
def __dlpack_device__(self):
return (kDLCPU, 0)

def init(unsigned int flags, unsigned int wm_log_level = 3):
check_wholememory_error_code(wholememory_init(flags, wm_log_level))
def init(unsigned int flags, LogLevel log_level = LEVEL_INFO):
check_wholememory_error_code(wholememory_init(flags, log_level))

def finalize():
check_wholememory_error_code(wholememory_finalize())
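A minimal sketch of driving the low-level binding directly with the new enum (most applications go through pylibwholegraph.torch instead; single-process use with flags=0 assumed):

```python
import pylibwholegraph.binding.wholememory_binding as wmb

# Pass the new enum explicitly; the default is LevInfo, mirroring LEVEL_INFO in the C header.
wmb.init(0, wmb.WholeMemoryLogLevel.LevDebug)
# ... create communicators / WholeMemory tensors ...
wmb.finalize()
```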
@@ -103,9 +103,9 @@ def host_sample_all_neighbors(
output_id = output_sample_offset_tensor[i]
for j in range(end - start):
output_dest_tensor[output_id + j] = host_csr_col_ptr[start + j]
output_center_localid_tensor[output_id + j] = node_id
output_center_localid_tensor[output_id + j] = i
output_edge_gid_tensor[output_id + j] = start + j
return output_dest_tensor, output_center_localid_tensor, output_edge_gid_tensor
return output_sample_offset_tensor, output_dest_tensor, output_center_localid_tensor, output_edge_gid_tensor


def copy_host_1D_tensor_to_wholememory(
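The fix records each center node's position within the batch (i) rather than its global node id, and also returns the offsets tensor. A hypothetical consistency check, assuming the usual CSR-style convention that output_sample_offset_tensor[i] and output_sample_offset_tensor[i + 1] bound node i's slice of the flat outputs:

```python
# Variable names follow the reference sampler above; center_node_count is the
# number of input center nodes.
for i in range(center_node_count):
    start = int(output_sample_offset_tensor[i])
    end = int(output_sample_offset_tensor[i + 1])
    # every sampled edge in this slice must point back to batch position i,
    # which is exactly what the corrected assignment guarantees
    assert (output_center_localid_tensor[start:end] == i).all()
```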
@@ -359,7 +359,7 @@ def routine_func(world_rank: int, world_size: int, **kwargs):

@pytest.mark.parametrize("graph_node_count", [103])
@pytest.mark.parametrize("graph_edge_count", [1043])
@pytest.mark.parametrize("max_sample_count", [11])
@pytest.mark.parametrize("max_sample_count", [11, -1])
@pytest.mark.parametrize("center_node_count", [13])
@pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("col_id_dtype", [0, 1])
@@ -9,7 +9,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.ß
# limitations under the License.

from argparse import ArgumentParser

9 changes: 4 additions & 5 deletions python/pylibwholegraph/pylibwholegraph/torch/initialize.py
@@ -16,15 +16,15 @@
import torch.utils.dlpack
import pylibwholegraph.binding.wholememory_binding as wmb
from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators
from .utils import str_to_wmb_wholememory_log_level


def init(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"):
log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5}
wmb.init(0, log_level_dic[wm_log_level])
wmb.init(0, str_to_wmb_wholememory_log_level(wm_log_level))
set_world_info(world_rank, world_size, local_rank, local_size)


def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level):
def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"):
r"""Init WholeGraph environment for PyTorch.
:param world_rank: world rank of current process
:param world_size: world size of all processes
@@ -45,8 +45,7 @@ def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size
print("[WARNING] MASTER_PORT not set, resetting to 12335")
os.environ["MASTER_PORT"] = "12335"

log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5}
wmb.init(0, log_level_dic[wm_log_level])
wmb.init(0, str_to_wmb_wholememory_log_level(wm_log_level))
torch.set_num_threads(1)
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
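A sketch of the high-level PyTorch entry point with the string-based level; the mapping to the binding enum now goes through str_to_wmb_wholememory_log_level. Rank and size values below are placeholders for a single-process run:

```python
from pylibwholegraph.torch.initialize import init

# Any of "error", "warn", "info", "debug", "trace" is accepted for wm_log_level.
init(world_rank=0, world_size=1, local_rank=0, local_size=1, wm_log_level="debug")
```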
18 changes: 18 additions & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/utils.py
@@ -99,6 +99,24 @@ def str_to_wmb_wholememory_memory_type(str_wmb_type: str):
)


def str_to_wmb_wholememory_log_level(str_log_level: str):
if str_log_level == "error":
return wmb.WholeMemoryLogLevel.LevError
elif str_log_level == "warn":
return wmb.WholeMemoryLogLevel.LevWarn
elif str_log_level == "info":
return wmb.WholeMemoryLogLevel.LevInfo
elif str_log_level == "debug":
return wmb.WholeMemoryLogLevel.LevDebug
elif str_log_level == "trace":
return wmb.WholeMemoryLogLevel.LevTrace
else:
raise ValueError(
"WholeMemory log level %s not supported, shold be (error, warn, info, debug, trace)"
% (str_log_level,)
)


def str_to_wmb_wholememory_location(str_wmb_location: str):
if str_wmb_location == "cuda":
return wmb.WholeMemoryMemoryLocation.MlDevice
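A short usage note for the new helper (import path inferred from the file location):

```python
from pylibwholegraph.torch.utils import str_to_wmb_wholememory_log_level

level = str_to_wmb_wholememory_log_level("warn")  # -> wmb.WholeMemoryLogLevel.LevWarn
try:
    str_to_wmb_wholememory_log_level("verbose")   # unsupported name
except ValueError as err:
    print(err)  # lists the accepted values
```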
