diff --git a/cpp/oneapi/dal/backend/communicator.cpp b/cpp/oneapi/dal/backend/communicator.cpp
index 5682ade0a1b..d0f7924864f 100644
--- a/cpp/oneapi/dal/backend/communicator.cpp
+++ b/cpp/oneapi/dal/backend/communicator.cpp
@@ -86,6 +86,14 @@ class fake_spmd_communicator_host_impl : public spmd::communicator_iface_base {
         return rank_count;
     }
 
+    bool get_mpi_offload_support() override {
+        return false;
+    }
+
+    bool use_sendrecv_replace_alternative() override {
+        return false;
+    }
+
     void barrier() override {}
 
     request_t* bcast(byte_t* send_buf,
@@ -122,7 +130,8 @@ class fake_spmd_communicator_host_impl : public spmd::communicator_iface_base {
                                 std::int64_t count,
                                 const data_type& dtype,
                                 std::int64_t destination_rank,
-                                std::int64_t source_rank) override {
+                                std::int64_t source_rank,
+                                byte_t* recv_buf = nullptr) override {
         return nullptr;
     }
 };
@@ -148,6 +157,14 @@ class fake_spmd_communicator_device_impl : public spmd::communicator_iface {
         return rank_count;
     }
 
+    bool get_mpi_offload_support() override {
+        return false;
+    }
+
+    bool use_sendrecv_replace_alternative() override {
+        return false;
+    }
+
     void barrier() override {}
 
     request_t* bcast(byte_t* send_buf,
@@ -225,7 +242,8 @@ class fake_spmd_communicator_device_impl : public spmd::communicator_iface {
                                 std::int64_t count,
                                 const data_type& dtype,
                                 std::int64_t destination_rank,
-                                std::int64_t source_rank) override {
+                                std::int64_t source_rank,
+                                byte_t* recv_buf = nullptr) override {
         return nullptr;
     }
     request_t* sendrecv_replace(sycl::queue& q,
diff --git a/cpp/oneapi/dal/detail/ccl/communicator.hpp b/cpp/oneapi/dal/detail/ccl/communicator.hpp
index bb291628305..d78d7fc8ddf 100644
--- a/cpp/oneapi/dal/detail/ccl/communicator.hpp
+++ b/cpp/oneapi/dal/detail/ccl/communicator.hpp
@@ -325,6 +325,18 @@ class ccl_communicator_impl : public ccl_interface_selector<memory_access_kind>::type {
         return default_root_;
     }
 
+    bool get_mpi_offload_support() override {
+        auto ccl_backend = ccl::get_library_version().cl_backend_name;
+        if (ccl_backend == "DPCPP") {
+            return true;
+        }
+        return false;
+    }
+
+    bool use_sendrecv_replace_alternative() override {
+        return false;
+    }
+
     void barrier() override {
         ccl::barrier(host_comm_->get_ref()).wait();
     }
@@ -396,7 +408,8 @@
                                 std::int64_t count,
                                 const data_type& dtype,
                                 std::int64_t destination_rank,
-                                std::int64_t source_rank) override {
+                                std::int64_t source_rank,
+                                byte_t* recv_buf = nullptr) override {
         ONEDAL_ASSERT(destination_rank >= 0);
         ONEDAL_ASSERT(source_rank >= 0);
         ONEDAL_ASSERT(destination_rank < rank_count_);
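The CCL-side capability check above relies entirely on which compute backend oneCCL reports at runtime. For reference, the same query can be exercised standalone; this is a minimal sketch assuming only an installed oneCCL, not code from this patch:

```cpp
#include <iostream>
#include <oneapi/ccl.hpp>

int main() {
    ccl::init();
    // cl_backend_name identifies the compute backend oneCCL was built with;
    // "DPCPP" means SYCL support, so collectives can consume device USM
    // pointers directly instead of requiring host staging buffers.
    const auto version = ccl::get_library_version();
    std::cout << "oneCCL backend: " << version.cl_backend_name << "\n";
    return 0;
}
```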
diff --git a/cpp/oneapi/dal/detail/communicator.cpp b/cpp/oneapi/dal/detail/communicator.cpp
index 388686810c1..8af2ddf1638 100644
--- a/cpp/oneapi/dal/detail/communicator.cpp
+++ b/cpp/oneapi/dal/detail/communicator.cpp
@@ -16,6 +16,7 @@
 
 #include "oneapi/dal/detail/common.hpp"
 #include "oneapi/dal/detail/communicator.hpp"
+#include "oneapi/dal/detail/profiler.hpp"
 #include "oneapi/dal/array.hpp"
 
 namespace spmd = oneapi::dal::preview::spmd;
@@ -52,15 +53,24 @@ spmd::request_iface* spmd_communicator_via_host_impl::bcast(sycl::queue& q,
     const std::int64_t dtype_size = get_data_type_size(dtype);
     const std::int64_t size = check_mul_overflow(dtype_size, count);
 
-    const auto send_buff_host = array<byte_t>::empty(size);
-    if (get_rank() == root) {
-        memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, size);
+    const bool gpu_offloading = get_mpi_offload_support();
+
+    if (gpu_offloading) {
+        ONEDAL_PROFILER_TASK(comm.bcast_gpu, q);
+        wait_request(bcast(send_buf, count, dtype, root));
     }
+    else {
+        ONEDAL_PROFILER_TASK(comm.bcast_cpu, q);
+        const auto send_buff_host = array<byte_t>::empty(size);
+        if (get_rank() == root) {
+            memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, size);
+        }
 
-    wait_request(bcast(send_buff_host.get_mutable_data(), count, dtype, root));
+        wait_request(bcast(send_buff_host.get_mutable_data(), count, dtype, root));
 
-    if (get_rank() != root) {
-        memcpy_host2usm(q, send_buf, send_buff_host.get_mutable_data(), size);
+        if (get_rank() != root) {
+            memcpy_host2usm(q, send_buf, send_buff_host.get_mutable_data(), size);
+        }
     }
 
     return nullptr;
@@ -102,35 +112,49 @@ spmd::request_iface* spmd_communicator_via_host_impl::allgatherv(
     const std::int64_t dtype_size = get_data_type_size(dtype);
     const std::int64_t send_size = check_mul_overflow(dtype_size, send_count);
     const std::int64_t total_recv_size = check_mul_overflow(dtype_size, total_recv_count);
-    // Workaround for zero send_size
-    const auto send_buff_host = array<byte_t>::empty(send_size > 0 ? send_size : 1);
-    if (send_size > 0) {
-        memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, send_size);
+
+    const bool gpu_offloading = get_mpi_offload_support();
+
+    if (gpu_offloading) {
+        ONEDAL_PROFILER_TASK(comm.allgatherv_gpu, q);
+        wait_request(
+            allgatherv(send_buf, send_count, recv_buf, recv_counts_host, displs_host, dtype));
     }
+    else {
+        ONEDAL_PROFILER_TASK(comm.allgatherv_cpu, q);
+        // Workaround for zero send_size
+        const auto send_buff_host = array<byte_t>::empty(send_size > 0 ? send_size : 1);
+        if (send_size > 0) {
+            memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, send_size);
+        }
 
-    array<byte_t> recv_buf_host;
-    byte_t* recv_buf_host_ptr = nullptr;
-    ONEDAL_ASSERT(total_recv_size > 0);
-    recv_buf_host.reset(total_recv_size);
-    recv_buf_host_ptr = recv_buf_host.get_mutable_data();
-    wait_request(allgatherv(send_buff_host.get_data(),
-                            send_count,
-                            recv_buf_host_ptr,
-                            recv_counts_host,
-                            displs_host,
-                            dtype));
-
-    const std::int64_t* displs_host_root_ptr = displs_host_root.get_data();
-    ONEDAL_ASSERT(displs_host_root_ptr);
-    ONEDAL_ASSERT(displs_host);
-    ONEDAL_ASSERT(recv_counts_host);
+        array<byte_t> recv_buf_host;
+        byte_t* recv_buf_host_ptr = nullptr;
+        ONEDAL_ASSERT(total_recv_size > 0);
+        recv_buf_host.reset(total_recv_size);
+        recv_buf_host_ptr = recv_buf_host.get_mutable_data();
+        wait_request(allgatherv(send_buff_host.get_data(),
+                                send_count,
+                                recv_buf_host_ptr,
+                                recv_counts_host,
+                                displs_host,
+                                dtype));
+
+        const std::int64_t* displs_host_root_ptr = displs_host_root.get_data();
+        ONEDAL_ASSERT(displs_host_root_ptr);
+        ONEDAL_ASSERT(displs_host);
+        ONEDAL_ASSERT(recv_counts_host);
 
-    for (std::int64_t i = 0; i < rank_count; i++) {
-        const std::int64_t src_offset = check_mul_overflow(dtype_size, displs_host_root_ptr[i]);
-        const std::int64_t dst_offset = check_mul_overflow(dtype_size, displs_host[i]);
-        const std::int64_t copy_size = check_mul_overflow(dtype_size, recv_counts_host[i]);
-        if (copy_size > 0) {
-            memcpy_host2usm(q, recv_buf + dst_offset, recv_buf_host_ptr + src_offset, copy_size);
+        for (std::int64_t i = 0; i < rank_count; i++) {
+            const std::int64_t src_offset = check_mul_overflow(dtype_size, displs_host_root_ptr[i]);
+            const std::int64_t dst_offset = check_mul_overflow(dtype_size, displs_host[i]);
+            const std::int64_t copy_size = check_mul_overflow(dtype_size, recv_counts_host[i]);
+            if (copy_size > 0) {
+                memcpy_host2usm(q,
+                                recv_buf + dst_offset,
+                                recv_buf_host_ptr + src_offset,
+                                copy_size);
+            }
         }
     }
     return nullptr;
@@ -161,13 +185,25 @@ spmd::request_iface* spmd_communicator_via_host_impl::allreduce(
     const std::int64_t dtype_size = get_data_type_size(dtype);
     const std::int64_t byte_count = check_mul_overflow(dtype_size, count);
 
-    const auto send_buff_host = array<byte_t>::empty(byte_count);
-    const auto recv_buf_host = array<byte_t>::empty(byte_count);
+    const bool gpu_offloading = get_mpi_offload_support();
 
-    memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, byte_count);
-    wait_request(
-        allreduce(send_buff_host.get_data(), recv_buf_host.get_mutable_data(), count, dtype, op));
-    memcpy_host2usm(q, recv_buf, recv_buf_host.get_data(), byte_count);
+    if (gpu_offloading) {
+        ONEDAL_PROFILER_TASK(comm.allreduce_gpu, q);
+        wait_request(allreduce(send_buf, recv_buf, count, dtype, op));
+    }
+    else {
+        ONEDAL_PROFILER_TASK(comm.allreduce_cpu, q);
+        const auto send_buff_host = array<byte_t>::empty(byte_count);
+        const auto recv_buf_host = array<byte_t>::empty(byte_count);
+
+        memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, byte_count);
+        wait_request(allreduce(send_buff_host.get_data(),
+                               recv_buf_host.get_mutable_data(),
+                               count,
+                               dtype,
+                               op));
+        memcpy_host2usm(q, recv_buf, recv_buf_host.get_data(), byte_count);
+    }
 
     return nullptr;
 }
@@ -198,16 +234,39 @@ spmd::request_iface* spmd_communicator_via_host_impl::sendrecv_replace(
     const std::int64_t dtype_size = get_data_type_size(dtype);
     const std::int64_t size = check_mul_overflow(dtype_size, count);
 
-    const auto buff_host = array<byte_t>::empty(size);
-    memcpy_usm2host(q, buff_host.get_mutable_data(), buf, size);
-
-    wait_request(sendrecv_replace(buff_host.get_mutable_data(),
-                                  count,
-                                  dtype,
-                                  destination_rank,
-                                  source_rank));
-
-    memcpy_host2usm(q, buf, buff_host.get_mutable_data(), size);
+    const bool gpu_offloading = get_mpi_offload_support();
+
+    if (gpu_offloading) {
+        ONEDAL_PROFILER_TASK(comm.srr_gpu, q);
+        const bool mpich_sendrecv = use_sendrecv_replace_alternative();
+        if (mpich_sendrecv) {
+            static byte_t* recv_buf = nullptr;
+            static bool initialized = false;
+            if (!initialized) {
+                recv_buf = sycl::malloc_device<byte_t>(size, q);
+                initialized = true;
+            }
+            wait_request(
+                sendrecv_replace(buf, count, dtype, destination_rank, source_rank, recv_buf));
+            q.memcpy(buf, recv_buf, size).wait();
+        }
+        else {
+            wait_request(sendrecv_replace(buf, count, dtype, destination_rank, source_rank));
+        }
+    }
+    else {
+        ONEDAL_PROFILER_TASK(comm.srr_cpu, q);
+        const auto buff_host = array<byte_t>::empty(size);
+        memcpy_usm2host(q, buff_host.get_mutable_data(), buf, size);
+
+        wait_request(sendrecv_replace(buff_host.get_mutable_data(),
+                                      count,
+                                      dtype,
+                                      destination_rank,
+                                      source_rank));
+
+        memcpy_host2usm(q, buf, buff_host.get_mutable_data(), size);
+    }
 
     return nullptr;
 }
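Every collective above follows the same shape: when get_mpi_offload_support() reports that the backend can consume device memory, the USM pointer is passed straight through (the *_gpu profiler branches); otherwise the pre-existing host-staging path runs unchanged (the *_cpu branches). A condensed sketch of that control flow, with a hypothetical host_allreduce() standing in for the MPI/CCL call:

```cpp
#include <sycl/sycl.hpp>
#include <cstdint>
#include <vector>

void host_allreduce(float* buf, std::int64_t count); // hypothetical backend call

void allreduce_sketch(sycl::queue& q,
                      float* device_buf,
                      std::int64_t count,
                      bool gpu_offloading) {
    if (gpu_offloading) {
        // Offload-capable backend: hand over the device pointer, no copies.
        host_allreduce(device_buf, count);
    }
    else {
        // Fallback: stage through host memory, exactly as before this patch.
        std::vector<float> host_buf(count);
        q.memcpy(host_buf.data(), device_buf, count * sizeof(float)).wait();
        host_allreduce(host_buf.data(), count);
        q.memcpy(device_buf, host_buf.data(), count * sizeof(float)).wait();
    }
}
```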
diff --git a/cpp/oneapi/dal/detail/mpi/communicator.hpp b/cpp/oneapi/dal/detail/mpi/communicator.hpp
index cb406b289a6..28f6f7589ba 100644
--- a/cpp/oneapi/dal/detail/mpi/communicator.hpp
+++ b/cpp/oneapi/dal/detail/mpi/communicator.hpp
@@ -21,6 +21,9 @@
 // TODO: In the future this can be solved via __has_include C++17 feature
 #include <mpi.h>
+#include <dlfcn.h>
+#include <string>
+#include <sstream>
 
 #include "oneapi/dal/detail/communicator.hpp"
@@ -146,6 +149,66 @@ class mpi_communicator_impl : public via_host_interface_selector<memory_access_kind>::type {
+    bool get_mpi_offload_support() override {
+        // Resolve the Intel MPI extension symbol at runtime so that builds
+        // against other MPI implementations do not gain a link dependency
+        void* handle = dlopen(NULL, RTLD_NOW);
+        if (handle == nullptr) {
+            return false;
+        }
+
+        void* sym = dlsym(handle, "MPIX_Query_ze_support");
+        if (sym == nullptr) {
+            dlclose(handle);
+            // Fallback: parse the library version string; Intel MPI 2021 and
+            // newer can consume device (USM) buffers directly
+            char version[MPI_MAX_LIBRARY_VERSION_STRING];
+            int len = 0;
+            MPI_Get_library_version(version, &len);
+            std::string version_str(version);
+            const std::size_t pos = version_str.find_first_of("0123456789");
+            if (pos == std::string::npos) {
+                return false;
+            }
+            int major = 0;
+            std::stringstream ss(version_str.substr(pos));
+            ss >> major;
+
+            return (major >= 2021);
+        }
+
+        // Return status of MPI ze support using pointer to function
+        typedef int (*MPIX_Query_ze_support_ptr)();
+        MPIX_Query_ze_support_ptr query_ze_support_ptr = (MPIX_Query_ze_support_ptr)sym;
+
+        bool result = query_ze_support_ptr();
+        dlclose(handle);
+
+        return result;
+    }
+
+    bool use_sendrecv_replace_alternative() override {
+        char version[MPI_MAX_LIBRARY_VERSION_STRING];
+        int len = 0;
+        MPI_Get_library_version(version, &len);
+        std::string version_str(version);
+        if (version_str.compare(0, 5, "MPICH") == 0) {
+            return true;
+        }
+        else {
+            return false;
+        }
+    }
+
     void barrier() override {
         mpi_call(MPI_Barrier(mpi_comm_));
     }
@@ -244,23 +307,13 @@ class mpi_communicator_impl : public via_host_interface_selector<memory_access_kind>::type {
-            const auto recv_buf_backup = array<byte_t>::empty(size);
-            // TODO Replace with MPI_Iallreduce
-            mpi_call(MPI_Allreduce(send_buf,
-                                   recv_buf_backup.get_mutable_data(),
+            mpi_call(MPI_Allreduce(MPI_IN_PLACE,
+                                   recv_buf,
                                    integral_cast<int>(count),
                                    make_mpi_data_type(dtype),
                                    make_mpi_reduce_op(op),
                                    mpi_comm_));
-
-            memcpy(default_host_policy{}, recv_buf, recv_buf_backup.get_data(), size);
-
-            // We have to copy memory after reduction, this cannot be performed
-            // asynchronously in the current implementation, so we return `nullptr`
-            // indicating that operation was performed synchronously
             return nullptr;
         }
     }
@@ -269,7 +322,8 @@ class mpi_communicator_impl : public via_host_interface_selector<memory_access_kind>::type {
                                 std::int64_t count,
                                 const data_type& dtype,
                                 std::int64_t destination_rank,
-                                std::int64_t source_rank) override {
+                                std::int64_t source_rank,
+                                byte_t* recv_buf = nullptr) override {
         ONEDAL_ASSERT(destination_rank >= 0);
         ONEDAL_ASSERT(source_rank >= 0);
@@ -282,15 +336,34 @@ class mpi_communicator_impl : public via_host_interface_selector<memory_access_kind>::type {
         MPI_Status status;
-        mpi_call(MPI_Sendrecv_replace(buf,
-                                      integral_cast<int>(count),
-                                      make_mpi_data_type(dtype),
-                                      integral_cast<int>(destination_rank),
-                                      zero_tag,
-                                      integral_cast<int>(source_rank),
-                                      zero_tag,
-                                      mpi_comm_,
-                                      &status));
+
+        if (recv_buf) {
+            // MPICH-specific workaround for GPU performance
+            mpi_call(MPI_Sendrecv(buf,
+                                  integral_cast<int>(count),
+                                  make_mpi_data_type(dtype),
+                                  integral_cast<int>(destination_rank),
+                                  zero_tag,
+                                  recv_buf,
+                                  integral_cast<int>(count),
+                                  make_mpi_data_type(dtype),
+                                  integral_cast<int>(source_rank),
+                                  zero_tag,
+                                  mpi_comm_,
+                                  &status));
+        }
+        else {
+            // Standard MPI_Sendrecv_replace call provided by the MPI backend
+            mpi_call(MPI_Sendrecv_replace(buf,
+                                          integral_cast<int>(count),
+                                          make_mpi_data_type(dtype),
+                                          integral_cast<int>(destination_rank),
+                                          zero_tag,
+                                          integral_cast<int>(source_rank),
+                                          zero_tag,
+                                          mpi_comm_,
+                                          &status));
+        }
 
         return nullptr;
     }
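get_mpi_offload_support() deliberately resolves MPIX_Query_ze_support() through dlsym() rather than calling it directly, so the header still compiles against MPI implementations that lack the Intel extension. The probe can be reproduced standalone; a minimal sketch (runnable against any MPI, the symbol is simply absent where unsupported):

```cpp
#include <dlfcn.h>
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);

    // dlopen(NULL) yields a handle to the whole process image, so dlsym()
    // searches every library the MPI runtime has already loaded.
    void* handle = dlopen(NULL, RTLD_NOW);
    if (handle) {
        typedef int (*query_fn)();
        query_fn query = (query_fn)dlsym(handle, "MPIX_Query_ze_support");
        if (query) {
            std::printf("Level Zero support reported: %d\n", query());
        }
        else {
            std::printf("MPIX_Query_ze_support not exported by this MPI\n");
        }
        dlclose(handle);
    }

    MPI_Finalize();
    return 0;
}
```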
diff --git a/cpp/oneapi/dal/spmd/communicator.hpp b/cpp/oneapi/dal/spmd/communicator.hpp
index 0e7aa5892d1..b5eebf27ef7 100644
--- a/cpp/oneapi/dal/spmd/communicator.hpp
+++ b/cpp/oneapi/dal/spmd/communicator.hpp
@@ -66,6 +66,8 @@ class communicator_iface_base {
     virtual std::int64_t get_rank() = 0;
     virtual std::int64_t get_rank_count() = 0;
     virtual std::int64_t get_default_root_rank() = 0;
+    virtual bool get_mpi_offload_support() = 0;
+    virtual bool use_sendrecv_replace_alternative() = 0;
 
     virtual void barrier() = 0;
 
@@ -88,7 +90,8 @@ class communicator_iface_base {
                                         std::int64_t count,
                                         const data_type& dtype,
                                         std::int64_t destination_rank,
-                                        std::int64_t source_rank) = 0;
+                                        std::int64_t source_rank,
+                                        byte_t* recv_buf = nullptr) = 0;
 };
 
 template <typename MemoryAccessKind>
diff --git a/cpp/oneapi/dal/test/engine/thread_communicator.cpp b/cpp/oneapi/dal/test/engine/thread_communicator.cpp
index 0b981740338..2b090efd3dd 100644
--- a/cpp/oneapi/dal/test/engine/thread_communicator.cpp
+++ b/cpp/oneapi/dal/test/engine/thread_communicator.cpp
@@ -500,8 +500,8 @@ auto thread_communicator_impl::sendrecv_replace(byte_t* buf,
                                                 std::int64_t count,
                                                 const data_type& dtype,
                                                 std::int64_t destination_rank,
-                                                std::int64_t source_rank)
-    -> request_t* {
+                                                std::int64_t source_rank,
+                                                byte_t* recv_buf) -> request_t* {
     collective_operation_guard guard{ ctx_ };
     sendrecv_replace_(buf, count, dtype, destination_rank, source_rank);
     return nullptr;
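The recv_buf parameter threaded through all of these interfaces exists only for the MPICH path: use_sendrecv_replace_alternative() routes sendrecv_replace through a plain MPI_Sendrecv into a separate device scratch buffer followed by a local copy-back, avoiding the slow in-place path observed with device memory. The equivalence in isolation, as a sketch (per-call allocation for clarity; the patch itself caches the scratch buffer in a function-local static sized by the first call):

```cpp
#include <mpi.h>
#include <sycl/sycl.hpp>
#include <cstdint>

// Rebuild MPI_Sendrecv_replace from MPI_Sendrecv: send from buf, receive
// into a distinct device buffer, then copy the payload back over buf.
void sendrecv_replace_alt(sycl::queue& q,
                          MPI_Comm comm,
                          std::uint8_t* buf,
                          int count,
                          int dest,
                          int src) {
    std::uint8_t* scratch = sycl::malloc_device<std::uint8_t>(count, q);
    MPI_Status status;
    MPI_Sendrecv(buf, count, MPI_BYTE, dest, 0,
                 scratch, count, MPI_BYTE, src, 0, comm, &status);
    q.memcpy(buf, scratch, count).wait(); // the "replace" step
    sycl::free(scratch, q);
}
```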
diff --git a/cpp/oneapi/dal/test/engine/thread_communicator.hpp b/cpp/oneapi/dal/test/engine/thread_communicator.hpp
index f7b5b119391..4cd0a2053c4 100644
--- a/cpp/oneapi/dal/test/engine/thread_communicator.hpp
+++ b/cpp/oneapi/dal/test/engine/thread_communicator.hpp
@@ -382,6 +382,14 @@ class thread_communicator_impl
         return ctx_.get_thread_count();
     }
 
+    bool get_mpi_offload_support() override {
+        return false;
+    }
+
+    bool use_sendrecv_replace_alternative() override {
+        return false;
+    }
+
     void barrier() override;
 
     request_t* bcast(byte_t* send_buf,
@@ -406,7 +414,8 @@ class thread_communicator_impl
                      std::int64_t count,
                      const data_type& dtype,
                      std::int64_t destination_rank,
-                     std::int64_t source_rank) override;
+                     std::int64_t source_rank,
+                     byte_t* recv_buf = nullptr) override;
 
 private:
     thread_communicator_context ctx_;