Skip to content

Commit

Permalink
MPI GPU interface refactoring (#2577)
Browse files Browse the repository at this point in the history
* host transfers to thread_comm, dev upds

* memory from detail not backend

* add detail prefix

* using memcpy with policy instead of queue

* clang

* temp for debug build

* revert last

* forgot to add deps wait

* debugging

* debug rd 2

* debug follow up

* remove throw error just print

* add mpi-based conditional

* check if MPIX_Query_ze_support present

* clang

* debug

* add CCL dummy function

* debug pt 2

* remove debug prints and update bool logic

* removing unneeded threadcomm queues

* minor follow up to last

* additional cleanup

* add some host logic to common in allgatherv

* first attempt at ccl

* debug follow up

* mpi offload true debug revision

* revert last

* ccl samples debug

* comment problematic statement

* revert example debugs

* fixing ccl dispatching issue

* intel parse with debug

* ccl debug

* remove debug

* fixes with debug

* remove debug

* minor revisions based on comments

* temporary - trying MPI suggestion for mpich perf

* minor follow-up fixes

* trying to contain to mpi comm

* trying with optional queue arg

* trying buf instead

* debug build

* forgot semi colons

* remove debugging

* cache additional buffer

* printf debugging

* public CI fix

* revert debug

* mpich condition with prints

* fix public CI

* clang

* add function for workaround

* clang

* another debug...

* alternative workaround

* remove debug

* create function to identify mpi backend

* revert previous

* revised workaround condition

* Temporarily comment mpich sendrecv workaround

* Update cpp/oneapi/dal/detail/mpi/communicator.hpp

* remove temporary workaround
  • Loading branch information
ethanglaser authored Aug 30, 2024
1 parent 758f4cc commit a8df345
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 77 deletions.
22 changes: 20 additions & 2 deletions cpp/oneapi/dal/backend/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ class fake_spmd_communicator_host_impl : public spmd::communicator_iface_base {
return rank_count;
}

bool get_mpi_offload_support() override {
return false;
}

bool use_sendrecv_replace_alternative() override {
return false;
}

void barrier() override {}

request_t* bcast(byte_t* send_buf,
Expand Down Expand Up @@ -122,7 +130,8 @@ class fake_spmd_communicator_host_impl : public spmd::communicator_iface_base {
std::int64_t count,
const data_type& dtype,
std::int64_t destination_rank,
std::int64_t source_rank) override {
std::int64_t source_rank,
byte_t* recv_buf = nullptr) override {
return nullptr;
}
};
Expand All @@ -148,6 +157,14 @@ class fake_spmd_communicator_device_impl : public spmd::communicator_iface {
return rank_count;
}

bool get_mpi_offload_support() override {
return false;
}

bool use_sendrecv_replace_alternative() override {
return false;
}

void barrier() override {}

request_t* bcast(byte_t* send_buf,
Expand Down Expand Up @@ -225,7 +242,8 @@ class fake_spmd_communicator_device_impl : public spmd::communicator_iface {
std::int64_t count,
const data_type& dtype,
std::int64_t destination_rank,
std::int64_t source_rank) override {
std::int64_t source_rank,
byte_t* recv_buf = nullptr) override {
return nullptr;
}
request_t* sendrecv_replace(sycl::queue& q,
Expand Down
15 changes: 14 additions & 1 deletion cpp/oneapi/dal/detail/ccl/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ class ccl_communicator_impl : public ccl_interface_selector<MemoryAccessKind>::t
return default_root_;
}

bool get_mpi_offload_support() override {
auto ccl_backend = ccl::get_library_version().cl_backend_name;
if (ccl_backend == "DPCPP") {
return true;
}
return false;
}

bool use_sendrecv_replace_alternative() override {
return false;
}

void barrier() override {
ccl::barrier(host_comm_->get_ref()).wait();
}
Expand Down Expand Up @@ -396,7 +408,8 @@ class ccl_communicator_impl : public ccl_interface_selector<MemoryAccessKind>::t
std::int64_t count,
const data_type& dtype,
std::int64_t destination_rank,
std::int64_t source_rank) override {
std::int64_t source_rank,
byte_t* recv_buf = nullptr) override {
ONEDAL_ASSERT(destination_rank >= 0);
ONEDAL_ASSERT(source_rank >= 0);
ONEDAL_ASSERT(destination_rank < rank_count_);
Expand Down
155 changes: 107 additions & 48 deletions cpp/oneapi/dal/detail/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "oneapi/dal/detail/common.hpp"
#include "oneapi/dal/detail/communicator.hpp"
#include "oneapi/dal/detail/profiler.hpp"
#include "oneapi/dal/array.hpp"

namespace spmd = oneapi::dal::preview::spmd;
Expand Down Expand Up @@ -52,15 +53,24 @@ spmd::request_iface* spmd_communicator_via_host_impl::bcast(sycl::queue& q,
const std::int64_t dtype_size = get_data_type_size(dtype);
const std::int64_t size = check_mul_overflow(dtype_size, count);

const auto send_buff_host = array<byte_t>::empty(size);
if (get_rank() == root) {
memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, size);
const bool gpu_offloading = get_mpi_offload_support();

if (gpu_offloading) {
ONEDAL_PROFILER_TASK(comm.bcast_gpu, q);
wait_request(bcast(send_buf, count, dtype, root));
}
else {
ONEDAL_PROFILER_TASK(comm.bcast_gpu, q);
const auto send_buff_host = array<byte_t>::empty(size);
if (get_rank() == root) {
memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, size);
}

wait_request(bcast(send_buff_host.get_mutable_data(), count, dtype, root));
wait_request(bcast(send_buff_host.get_mutable_data(), count, dtype, root));

if (get_rank() != root) {
memcpy_host2usm(q, send_buf, send_buff_host.get_mutable_data(), size);
if (get_rank() != root) {
memcpy_host2usm(q, send_buf, send_buff_host.get_mutable_data(), size);
}
}

return nullptr;
Expand Down Expand Up @@ -102,35 +112,49 @@ spmd::request_iface* spmd_communicator_via_host_impl::allgatherv(
const std::int64_t dtype_size = get_data_type_size(dtype);
const std::int64_t send_size = check_mul_overflow(dtype_size, send_count);
const std::int64_t total_recv_size = check_mul_overflow(dtype_size, total_recv_count);
// Workaround for zero send_size
const auto send_buff_host = array<byte_t>::empty(send_size > 0 ? send_size : 1);
if (send_size > 0) {
memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, send_size);

const bool gpu_offloading = get_mpi_offload_support();

if (gpu_offloading) {
ONEDAL_PROFILER_TASK(comm.allgatherv_gpu, q);
wait_request(
allgatherv(send_buf, send_count, recv_buf, recv_counts_host, displs_host, dtype));
}
else {
ONEDAL_PROFILER_TASK(comm.allgatherv_cpu, q);
// Workaround for zero send_size
const auto send_buff_host = array<byte_t>::empty(send_size > 0 ? send_size : 1);
if (send_size > 0) {
memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, send_size);
}

array<byte_t> recv_buf_host;
byte_t* recv_buf_host_ptr = nullptr;
ONEDAL_ASSERT(total_recv_size > 0);
recv_buf_host.reset(total_recv_size);
recv_buf_host_ptr = recv_buf_host.get_mutable_data();
wait_request(allgatherv(send_buff_host.get_data(),
send_count,
recv_buf_host_ptr,
recv_counts_host,
displs_host,
dtype));

const std::int64_t* displs_host_root_ptr = displs_host_root.get_data();
ONEDAL_ASSERT(displs_host_root_ptr);
ONEDAL_ASSERT(displs_host);
ONEDAL_ASSERT(recv_counts_host);
array<byte_t> recv_buf_host;
byte_t* recv_buf_host_ptr = nullptr;
ONEDAL_ASSERT(total_recv_size > 0);
recv_buf_host.reset(total_recv_size);
recv_buf_host_ptr = recv_buf_host.get_mutable_data();
wait_request(allgatherv(send_buff_host.get_data(),
send_count,
recv_buf_host_ptr,
recv_counts_host,
displs_host,
dtype));

const std::int64_t* displs_host_root_ptr = displs_host_root.get_data();
ONEDAL_ASSERT(displs_host_root_ptr);
ONEDAL_ASSERT(displs_host);
ONEDAL_ASSERT(recv_counts_host);

for (std::int64_t i = 0; i < rank_count; i++) {
const std::int64_t src_offset = check_mul_overflow(dtype_size, displs_host_root_ptr[i]);
const std::int64_t dst_offset = check_mul_overflow(dtype_size, displs_host[i]);
const std::int64_t copy_size = check_mul_overflow(dtype_size, recv_counts_host[i]);
if (copy_size > 0) {
memcpy_host2usm(q, recv_buf + dst_offset, recv_buf_host_ptr + src_offset, copy_size);
for (std::int64_t i = 0; i < rank_count; i++) {
const std::int64_t src_offset = check_mul_overflow(dtype_size, displs_host_root_ptr[i]);
const std::int64_t dst_offset = check_mul_overflow(dtype_size, displs_host[i]);
const std::int64_t copy_size = check_mul_overflow(dtype_size, recv_counts_host[i]);
if (copy_size > 0) {
memcpy_host2usm(q,
recv_buf + dst_offset,
recv_buf_host_ptr + src_offset,
copy_size);
}
}
}
return nullptr;
Expand Down Expand Up @@ -161,13 +185,25 @@ spmd::request_iface* spmd_communicator_via_host_impl::allreduce(
const std::int64_t dtype_size = get_data_type_size(dtype);
const std::int64_t byte_count = check_mul_overflow(dtype_size, count);

const auto send_buff_host = array<byte_t>::empty(byte_count);
const auto recv_buf_host = array<byte_t>::empty(byte_count);
const bool gpu_offloading = get_mpi_offload_support();

memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, byte_count);
wait_request(
allreduce(send_buff_host.get_data(), recv_buf_host.get_mutable_data(), count, dtype, op));
memcpy_host2usm(q, recv_buf, recv_buf_host.get_data(), byte_count);
if (gpu_offloading) {
ONEDAL_PROFILER_TASK(comm.allreduce_gpu, q);
wait_request(allreduce(send_buf, recv_buf, count, dtype, op));
}
else {
ONEDAL_PROFILER_TASK(comm.allreduce_cpu, q);
const auto send_buff_host = array<byte_t>::empty(byte_count);
const auto recv_buf_host = array<byte_t>::empty(byte_count);

memcpy_usm2host(q, send_buff_host.get_mutable_data(), send_buf, byte_count);
wait_request(allreduce(send_buff_host.get_data(),
recv_buf_host.get_mutable_data(),
count,
dtype,
op));
memcpy_host2usm(q, recv_buf, recv_buf_host.get_data(), byte_count);
}

return nullptr;
}
Expand Down Expand Up @@ -198,16 +234,39 @@ spmd::request_iface* spmd_communicator_via_host_impl::sendrecv_replace(
const std::int64_t dtype_size = get_data_type_size(dtype);
const std::int64_t size = check_mul_overflow(dtype_size, count);

const auto buff_host = array<byte_t>::empty(size);
memcpy_usm2host(q, buff_host.get_mutable_data(), buf, size);

wait_request(sendrecv_replace(buff_host.get_mutable_data(),
count,
dtype,
destination_rank,
source_rank));

memcpy_host2usm(q, buf, buff_host.get_mutable_data(), size);
const bool gpu_offloading = get_mpi_offload_support();

if (gpu_offloading) {
ONEDAL_PROFILER_TASK(comm.srr_gpu, q);
const bool mpich_sendrecv = use_sendrecv_replace_alternative();
if (mpich_sendrecv) {
static byte_t* recv_buf = nullptr;
static bool initialized = false;
if (!initialized) {
recv_buf = sycl::malloc_device<byte_t>(size, q);
initialized = true;
}
wait_request(
sendrecv_replace(buf, count, dtype, destination_rank, source_rank, recv_buf));
q.memcpy(buf, recv_buf, size).wait();
}
else {
wait_request(sendrecv_replace(buf, count, dtype, destination_rank, source_rank));
}
}
else {
ONEDAL_PROFILER_TASK(comm.srr_cpu, q);
const auto buff_host = array<byte_t>::empty(size);
memcpy_usm2host(q, buff_host.get_mutable_data(), buf, size);

wait_request(sendrecv_replace(buff_host.get_mutable_data(),
count,
dtype,
destination_rank,
source_rank));

memcpy_host2usm(q, buf, buff_host.get_mutable_data(), size);
}

return nullptr;
}
Expand Down
Loading

0 comments on commit a8df345

Please sign in to comment.