/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// File: cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
//
// NOTE(review): the angle-bracket header names of this file were lost when the
// patch was extracted (only bare `#include` tokens remained). The headers below
// are the ones the visible code demonstrably needs; restore any remaining
// project headers (e.g. whatever declares get_device_prop and the wholememory
// communicator API) from the upstream file.
#include "bucket_ids_for_hierarchy_func.h"

#include <algorithm>  // std::min
#include <cassert>    // assert in device code
#include <cstdint>    // int64_t

#include <cub/cub.cuh>  // cub::DeviceScan

#include "cuda_macros.hpp"
#include "error.hpp"
#include "logger.hpp"
#include "wholememory/integer_utils.hpp"
#include "wholememory_ops/register.hpp"
#include "wholememory_ops/temp_memory_handle.hpp"

namespace wholememory_ops {

/**
 * @brief Build a per-bucket histogram of the ranks that own @p indices.
 *
 * For every non-negative index the owning rank is
 * `index / embedding_entry_count_per_rank`. The bucket is `rank % local_size`
 * when CROSS_OR_LOCAL == 0 and `rank / local_size` otherwise. Negative indices
 * are skipped. Each block first counts into dynamic shared memory and then adds
 * its non-zero bins to @p dev_rank_id_count_ptr with one global atomic per bin.
 *
 * Launch requirements: 1-D grid and block; dynamic shared memory of at least
 * `bucket_size * sizeof(int)` bytes. @p dev_rank_id_count_ptr must hold
 * @p bucket_size counters that are zero before the launch. Buckets are not
 * range-checked, so every computed bucket must be smaller than @p bucket_size.
 *
 * @tparam IndexT          signed integer type of the indices
 * @tparam CROSS_OR_LOCAL  0: bucket by `rank % local_size`; otherwise by `rank / local_size`
 */
template <typename IndexT, int CROSS_OR_LOCAL>
__global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
                                                size_t indice_count,
                                                int64_t* dev_rank_id_count_ptr,
                                                size_t embedding_entry_count_per_rank,
                                                int local_size,
                                                int bucket_size)
{
  extern __shared__ int rank_count_shared[];
  for (int b = threadIdx.x; b < bucket_size; b += blockDim.x) {
    rank_count_shared[b] = 0;
  }
  __syncthreads();

  // 64-bit index and stride: a plain `int` would overflow once indice_count
  // exceeds INT_MAX.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < indice_count;
       idx += stride) {
    const IndexT node_idx = indices[idx];
    if (node_idx < 0) continue;
    const int rank   = node_idx / embedding_entry_count_per_rank;
    const int bucket = (CROSS_OR_LOCAL == 0) ? (rank % local_size)   // bucket cross ranks
                                             : (rank / local_size);  // bucket local ranks
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
    atomicAdd_block(&rank_count_shared[bucket], 1);
#else
    atomicAdd(&rank_count_shared[bucket], 1);
#endif
  }
  __syncthreads();

  for (int b = threadIdx.x; b < bucket_size; b += blockDim.x) {
    const int block_total = rank_count_shared[b];
    // Empty bins contribute nothing; skipping them saves a global atomic.
    if (block_total != 0) {
      atomicAdd(reinterpret_cast<unsigned long long*>(dev_rank_id_count_ptr) + b,
                static_cast<unsigned long long>(block_total));
    }
  }
}

/**
 * @brief Zero the bucket counters and launch bucket_ids_for_hierarchy_kernel.
 *
 * @param indices                         device pointer to the indices (element type IndexT)
 * @param indice_desc                     size and storage offset of @p indices
 * @param dev_rank_id_count_ptr           device counters; `local_size` entries when
 *                                        @p bucket_cross_or_local is 0, `cross_size` otherwise
 * @param embedding_entry_count_per_rank  divisor used to turn an index into a rank
 * @param local_size                      modulus / divisor used to turn a rank into a bucket
 * @param cross_size                      bucket count when @p bucket_cross_or_local != 0
 * @param bucket_cross_or_local           0 selects `rank % local_size`, anything else `rank / local_size`
 * @param sm_count                        the grid is capped at `4 * sm_count` blocks
 * @param stream                          stream all work is enqueued on (nothing is synchronized here)
 *
 * Failures of the memset or the launch are reported through WM_CUDA_CHECK
 * (presumably throws wholememory::cuda_error, which the callers in this file
 * catch -- confirm in cuda_macros.hpp).
 */
template <typename IndexT>
void bucket_ids_for_hierarchy_temp_func(const void* indices,
                                        wholememory_array_description_t indice_desc,
                                        int64_t* dev_rank_id_count_ptr,
                                        size_t embedding_entry_count_per_rank,
                                        int local_size,
                                        int cross_size,
                                        int bucket_cross_or_local,
                                        int sm_count,
                                        cudaStream_t stream)
{
  static constexpr int BLOCK_SIZE = 128;
  int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE);
  block_count     = std::min(block_count, sm_count * 4);
  const IndexT* indices_ptr = static_cast<const IndexT*>(indices) + indice_desc.storage_offset;

  const bool bucket_cross_ranks = (bucket_cross_or_local == 0);
  const int bucket_size         = bucket_cross_ranks ? local_size : cross_size;
  const size_t shared_bytes     = sizeof(int) * bucket_size;

  WM_CUDA_CHECK(
    cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream));
  // A grid of zero blocks is an invalid launch configuration; with no input the
  // zeroed counters are already the final answer.
  if (block_count <= 0) { return; }

  if (bucket_cross_ranks) {
    bucket_ids_for_hierarchy_kernel<IndexT, 0>
      <<<block_count, BLOCK_SIZE, shared_bytes, stream>>>(indices_ptr,
                                                          indice_desc.size,
                                                          dev_rank_id_count_ptr,
                                                          embedding_entry_count_per_rank,
                                                          local_size,
                                                          bucket_size);
  } else {
    bucket_ids_for_hierarchy_kernel<IndexT, 1>
      <<<block_count, BLOCK_SIZE, shared_bytes, stream>>>(indices_ptr,
                                                          indice_desc.size,
                                                          dev_rank_id_count_ptr,
                                                          embedding_entry_count_per_rank,
                                                          local_size,
                                                          bucket_size);
  }
  // Launch-configuration errors are only visible through cudaGetLastError().
  WM_CUDA_CHECK(cudaGetLastError());
}

REGISTER_DISPATCH_ONE_TYPE(BucketIdForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264)

/**
 * @brief Scatter indices into per-bucket contiguous ranges and record where each one went.
 *
 * Bucket of an index: `(index / embedding_entry_count_per_rank) % local_size`.
 * A block processes the input in tiles of `buffer_size` elements staged in a
 * fixed 24576-byte shared array. Per tile it
 *   1. loads the tile and counts the elements per bucket,
 *   2. computes the block-local start of each bucket (one thread) and reserves
 *      a range per bucket in @p dev_bucket_atomic_add_ptr,
 *   3. groups the tile by bucket in shared memory and writes
 *      `dev_indice_map[i] = dev_rank_id_offset_ptr[bucket] + reserved_start + slot`,
 *   4. copies each grouped bucket to @p dev_bucket_indices, one warp per bucket.
 *
 * Preconditions: 1-D grid and block; `local_size` small enough that the four
 * per-bucket arrays leave room for at least one `blockDim.x`-sized tile (checked
 * with assert); @p dev_bucket_atomic_add_ptr holds `local_size` zeroed counters.
 *
 * NOTE(review): bucket_ids_for_hierarchy_kernel skips negative indices while
 * this kernel buckets them like any other value, so the counts it is given and
 * the slots it fills disagree whenever negative indices are present. Confirm
 * what callers expect for negative indices before changing either side.
 */
template <typename IndexT>
__global__ void reorder_ids_for_cross_ranks_kernel(const IndexT* indices,
                                                   size_t indice_count,
                                                   IndexT* dev_bucket_indices,
                                                   IndexT* dev_indice_map,
                                                   const int64_t* dev_rank_id_offset_ptr,
                                                   size_t embedding_entry_count_per_rank,
                                                   int local_size,
                                                   int64_t* dev_bucket_atomic_add_ptr)
{
  const int nbucket                = local_size;
  constexpr size_t shared_mem_size = 24576;
  // The raw buffer is reinterpreted as int and IndexT arrays, so it must be
  // aligned for the widest of them; a bare char array guarantees nothing.
  __shared__ __align__(16) char shared_mem[shared_mem_size];

  // Bytes used by the four per-bucket arrays. Checked before the subtraction
  // below, which would otherwise wrap around and defeat the buffer_size assert.
  const size_t header_bytes = static_cast<size_t>(nbucket) * 2 * (sizeof(IndexT) + sizeof(int));
  assert(header_bytes < shared_mem_size);

  int* block_bucket_count_shared      = reinterpret_cast<int*>(shared_mem);
  int* block_bucket_atomic_add_shared = block_bucket_count_shared + nbucket;
  IndexT* block_bucket_offset_shared =
    reinterpret_cast<IndexT*>(shared_mem + 2 * sizeof(int) * nbucket);
  IndexT* global_bucket_offset_shared = block_bucket_offset_shared + nbucket;

  size_t buffer_size = (shared_mem_size - header_bytes) / sizeof(IndexT) / 2;
  buffer_size        = (buffer_size / blockDim.x) * blockDim.x;
  assert(buffer_size > 0);

  IndexT* buffer_load  = global_bucket_offset_shared + nbucket;
  IndexT* buffer_store = buffer_load + buffer_size;

  const int warp_idx = threadIdx.x / warpSize;
  const int lane_idx = threadIdx.x % warpSize;
  // Rounded up so that a trailing partial warp is counted and a block smaller
  // than one warp does not end up with a zero loop stride.
  const int nwarp        = (blockDim.x + warpSize - 1) / warpSize;
  const size_t grid_step = buffer_size * gridDim.x;

  // Offsets are size_t: an IndexT offset overflows for 32-bit index types.
  for (size_t load_offset = buffer_size * blockIdx.x; load_offset < indice_count;
       load_offset += grid_step) {
    for (int b = threadIdx.x; b < nbucket; b += blockDim.x) {
      block_bucket_count_shared[b]      = 0;
      block_bucket_atomic_add_shared[b] = 0;
    }
    __syncthreads();

    // Step 1: stage the tile and count elements per bucket.
    for (size_t i = threadIdx.x; i < buffer_size; i += blockDim.x) {
      const size_t load_idx = load_offset + i;
      if (load_idx >= indice_count) break;
      const IndexT indice = indices[load_idx];

      buffer_load[i]       = indice;
      const int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
      atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1);
#else
      atomicAdd(&block_bucket_count_shared[bucket_idx], 1);
#endif
    }
    __syncthreads();

    // Step 2a: exclusive prefix sum of the per-bucket counts (block-local starts).
    if (threadIdx.x == blockDim.x - 1) {
      IndexT bucket_offset_tmp = 0;
      for (int bi = 0; bi < nbucket; bi++) {
        block_bucket_offset_shared[bi] = bucket_offset_tmp;
        bucket_offset_tmp += block_bucket_count_shared[bi];
      }
    }
    // Step 2b: reserve this block's range in every bucket. A strided loop is
    // required: `if (threadIdx.x < nbucket)` leaves buckets >= blockDim.x unreserved.
    for (int bucket_idx = threadIdx.x; bucket_idx < nbucket; bucket_idx += blockDim.x) {
      global_bucket_offset_shared[bucket_idx] =
        atomicAdd(reinterpret_cast<unsigned long long*>(dev_bucket_atomic_add_ptr) + bucket_idx,
                  static_cast<unsigned long long>(block_bucket_count_shared[bucket_idx]));
    }
    __syncthreads();

    // Step 3: group the tile by bucket and record each element's destination.
    for (size_t i = threadIdx.x; i < buffer_size; i += blockDim.x) {
      const size_t load_idx = load_offset + i;
      if (load_idx >= indice_count) break;
      const IndexT indice  = buffer_load[i];
      const int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
      int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1);
#else
      int block_bucket_inc = atomicAdd(&block_bucket_atomic_add_shared[bucket_idx], 1);
#endif
      buffer_store[block_bucket_offset_shared[bucket_idx] + block_bucket_inc] = indice;
      dev_indice_map[load_idx] = dev_rank_id_offset_ptr[bucket_idx] +
                                 global_bucket_offset_shared[bucket_idx] + block_bucket_inc;
    }
    __syncthreads();

    // Step 4: one warp per bucket copies the grouped elements to global memory.
    for (int bucket_idx = warp_idx; bucket_idx < nbucket; bucket_idx += nwarp) {
      const int bucket_length = block_bucket_count_shared[bucket_idx];
      const IndexT global_bucket_offset =
        dev_rank_id_offset_ptr[bucket_idx] + global_bucket_offset_shared[bucket_idx];
      for (int idx = lane_idx; idx < bucket_length; idx += warpSize) {
        dev_bucket_indices[global_bucket_offset + idx] =
          buffer_store[block_bucket_offset_shared[bucket_idx] + idx];
      }
    }
    // The shared arrays are reset at the top of the next tile.
    __syncthreads();
  }
}

/**
 * @brief Turn per-bucket counts into offsets and launch reorder_ids_for_cross_ranks_kernel.
 *
 * @param indices                device indices (element type IndexT); storage_offset must be 0
 * @param indice_desc            description of @p indices; dtype must be INT or INT64
 * @param dev_bucket_indices     output: indices grouped by bucket
 * @param dev_indice_map         output: for each input position, its position in @p dev_bucket_indices
 * @param dev_rank_id_count_ptr  device array of `local_size` per-bucket counts
 * @param embedding_entry_count_per_rank  divisor used to turn an index into a rank
 * @param local_size             number of buckets
 * @param p_thrust_allocator     provides the CUB scan workspace
 * @param p_env_fns              environment functions used for temporary device buffers
 * @param sm_count               the grid is capped at `4 * sm_count` blocks
 * @param stream                 stream all work is enqueued on
 *
 * NOTE(review): the two temp_memory_handle buffers and the CUB workspace are
 * released while work that uses them may still be pending on @p stream. This
 * is only safe if the allocators behind them are stream-ordered or defer the
 * release -- confirm.
 */
template <typename IndexT>
void reorder_ids_for_cross_ranks_temp_func(const void* indices,
                                           wholememory_array_description_t indice_desc,
                                           void* dev_bucket_indices,
                                           void* dev_indice_map,
                                           const int64_t* dev_rank_id_count_ptr,
                                           size_t embedding_entry_count_per_rank,
                                           int local_size,
                                           wm_thrust_allocator* p_thrust_allocator,
                                           wholememory_env_func_t* p_env_fns,
                                           int sm_count,
                                           cudaStream_t stream)
{
  WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0);
  WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT ||
                    indice_desc.dtype == WHOLEMEMORY_DT_INT64);

  // Exclusive scan of the counts gives the start of every bucket.
  temp_memory_handle dev_rank_id_offset_handle(p_env_fns);
  int64_t* dev_rank_id_offset_ptr = static_cast<int64_t*>(
    dev_rank_id_offset_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
  size_t temp_storage_bytes = 0;
  // First call only queries the workspace size.
  WM_CUDA_CHECK(cub::DeviceScan::ExclusiveSum(nullptr,
                                              temp_storage_bytes,
                                              dev_rank_id_count_ptr,
                                              dev_rank_id_offset_ptr,
                                              local_size,
                                              stream));
  auto* cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes);
  const cudaError_t scan_status = cub::DeviceScan::ExclusiveSum(cub_temp_storage,
                                                                temp_storage_bytes,
                                                                dev_rank_id_count_ptr,
                                                                dev_rank_id_offset_ptr,
                                                                local_size,
                                                                stream);
  // Hand the workspace back before reporting, so a failure does not strand it.
  p_thrust_allocator->deallocate(cub_temp_storage, temp_storage_bytes);
  WM_CUDA_CHECK(scan_status);

  // Running per-bucket fill level, advanced atomically by the kernel.
  temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns);
  int64_t* dev_bucket_atomic_add_ptr = static_cast<int64_t*>(
    dev_bucket_atomic_add_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
  WM_CUDA_CHECK(
    cudaMemsetAsync(dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * local_size, stream));

  static constexpr int BLOCK_SIZE = 128;
  int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE);
  block_count     = std::min(block_count, sm_count * 4);
  if (block_count <= 0) { return; }  // nothing to reorder; avoid an invalid launch

  reorder_ids_for_cross_ranks_kernel<IndexT>
    <<<block_count, BLOCK_SIZE, 0, stream>>>(static_cast<const IndexT*>(indices),
                                             indice_desc.size,
                                             static_cast<IndexT*>(dev_bucket_indices),
                                             static_cast<IndexT*>(dev_indice_map),
                                             dev_rank_id_offset_ptr,
                                             embedding_entry_count_per_rank,
                                             local_size,
                                             dev_bucket_atomic_add_ptr);
  WM_CUDA_CHECK(cudaGetLastError());
}

REGISTER_DISPATCH_ONE_TYPE(ReorderIdForCrossRanks, reorder_ids_for_cross_ranks_temp_func, SINT3264)

/**
 * @brief Count indices per `rank % local_size` bucket, then group them by bucket.
 *
 * @param indices                 device indices
 * @param indice_desc             description of @p indices; returns immediately when size is 0
 * @param dev_bucket_indices      output: indices grouped by bucket
 * @param dev_indice_map          output: destination position of every input index
 * @param host_bucket_id_count    output: `local_size` counts, copied with cudaMemcpyAsync on
 *                                @p stream -- synchronize the stream before reading them
 * @param embedding_entry_count_per_rank  divisor used to turn an index into a rank
 * @param wm_global_comm          only its size is read (must be a multiple of the local size)
 * @param wm_local_comm           its size is the number of buckets
 * @param p_thrust_allocator      scratch allocator forwarded to the reorder step
 * @param p_env_fns               environment functions for temporary device buffers
 * @param stream                  stream all work is enqueued on
 * @return WHOLEMEMORY_SUCCESS, or an error code if a step threw
 *
 * NOTE(review): `dev_rank_id_count_handle` goes out of scope at return while
 * work on @p stream may still read it; confirm temp_memory_handle tolerates that.
 */
wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
  void* indices,
  wholememory_array_description_t indice_desc,
  void* dev_bucket_indices,
  void* dev_indice_map,
  int64_t* host_bucket_id_count,
  size_t embedding_entry_count_per_rank,
  wholememory_comm_t wm_global_comm,
  wholememory_comm_t wm_local_comm,
  wm_thrust_allocator* p_thrust_allocator,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream)
{
  if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
  int world_size, local_size;
  WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm));
  WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm));
  WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0);

  constexpr int K_DEFAULT_SM_COUNT = 108;
  auto prop    = get_device_prop(-1);
  int sm_count = (prop != nullptr) ? prop->multiProcessorCount : K_DEFAULT_SM_COUNT;

  // The counters are zeroed by bucket_ids_for_hierarchy_temp_func before use.
  temp_memory_handle dev_rank_id_count_handle(p_env_fns);
  int64_t* dev_rank_id_count_ptr =
    static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));

  // Every exception is turned into an error code here: this function reports
  // failures through its return value, like the reorder step below.
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      BucketIdForHierarchy,
                      indices,
                      indice_desc,
                      dev_rank_id_count_ptr,
                      embedding_entry_count_per_rank,
                      local_size,
                      0,  // cross_size is ignored when bucketing by rank % local_size
                      0,
                      sm_count,
                      stream);
  } catch (wholememory::cuda_error& wce) {
    WHOLEMEMORY_ERROR("bucket_ids_for_cross_ranks CUDA LOGIC Error %s\n", wce.what());
    return WHOLEMEMORY_CUDA_ERROR;
  } catch (wholememory::logic_error& wle) {
    WHOLEMEMORY_ERROR("bucket_ids_for_cross_ranks LOGIC Error %s\n", wle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_UNKNOW_ERROR;
  }
  WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count,
                                         dev_rank_id_count_ptr,
                                         local_size * sizeof(int64_t),
                                         cudaMemcpyDeviceToHost,
                                         stream));
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      ReorderIdForCrossRanks,
                      indices,
                      indice_desc,
                      dev_bucket_indices,
                      dev_indice_map,
                      dev_rank_id_count_ptr,
                      embedding_entry_count_per_rank,
                      local_size,
                      p_thrust_allocator,
                      p_env_fns,
                      sm_count,
                      stream);
  } catch (wholememory::cuda_error& wce) {
    WHOLEMEMORY_ERROR("reorder_ids_for_cross_ranks CUDA LOGIC Error %s\n", wce.what());
    return WHOLEMEMORY_CUDA_ERROR;
  } catch (wholememory::logic_error& wle) {
    WHOLEMEMORY_ERROR("reorder_ids_for_cross_ranks LOGIC Error %s\n", wle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_UNKNOW_ERROR;
  }
  return WHOLEMEMORY_SUCCESS;
}

/**
 * @brief Count indices per `rank / local_size` bucket.
 *
 * @param indices               device indices
 * @param indice_desc           description of @p indices; returns immediately when size is 0
 * @param host_bucket_id_count  output: `cross_size` counts, copied with cudaMemcpyAsync on
 *                              @p stream -- synchronize the stream before reading them
 * @param embedding_entry_count_per_rank  divisor used to turn an index into a rank
 * @param wm_local_comm         its size is the divisor that maps a rank to a bucket
 * @param wm_cross_comm         its size is the number of buckets
 * @param p_thrust_allocator    unused by this function
 * @param p_env_fns             environment functions for the temporary counter buffer
 * @param stream                stream all work is enqueued on
 * @return WHOLEMEMORY_SUCCESS, or an error code if the counting step threw
 */
wholememory_error_code_t bucket_local_ids_func(void* indices,
                                               wholememory_array_description_t indice_desc,
                                               int64_t* host_bucket_id_count,
                                               size_t embedding_entry_count_per_rank,
                                               wholememory_comm_t wm_local_comm,
                                               wholememory_comm_t wm_cross_comm,
                                               wm_thrust_allocator* p_thrust_allocator,
                                               wholememory_env_func_t* p_env_fns,
                                               cudaStream_t stream)
{
  if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
  int cross_size, local_size;
  WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm));
  WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm));

  constexpr int K_DEFAULT_SM_COUNT = 108;
  auto prop    = get_device_prop(-1);
  int sm_count = (prop != nullptr) ? prop->multiProcessorCount : K_DEFAULT_SM_COUNT;

  // The counters are zeroed by bucket_ids_for_hierarchy_temp_func before use.
  temp_memory_handle dev_rank_id_count_handle(p_env_fns);
  int64_t* dev_rank_id_count_ptr =
    static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(cross_size, WHOLEMEMORY_DT_INT64));
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      BucketIdForHierarchy,
                      indices,
                      indice_desc,
                      dev_rank_id_count_ptr,
                      embedding_entry_count_per_rank,
                      local_size,
                      cross_size,
                      1,
                      sm_count,
                      stream);
  } catch (wholememory::cuda_error& wce) {
    // The message previously named bucket_ids_for_cross_ranks (copy-paste).
    WHOLEMEMORY_ERROR("bucket_local_ids CUDA LOGIC Error %s\n", wce.what());
    return WHOLEMEMORY_CUDA_ERROR;
  } catch (wholememory::logic_error& wle) {
    WHOLEMEMORY_ERROR("bucket_local_ids LOGIC Error %s\n", wle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_UNKNOW_ERROR;
  }
  WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count,
                                         dev_rank_id_count_ptr,
                                         cross_size * sizeof(int64_t),
                                         cudaMemcpyDeviceToHost,
                                         stream));
  WM_CUDA_CHECK(cudaGetLastError());
  return WHOLEMEMORY_SUCCESS;
}

}  // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h new file mode 100644 index 000000000..a86a9945e --- /dev/null
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "wholememory_ops/temp_memory_handle.hpp" + +namespace wholememory_ops { + +wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_global_comm, + wholememory_comm_t wm_local_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +wholememory_error_code_t bucket_local_ids_func(void* indices, + wholememory_array_description_t indice_desc, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu new file mode 100644 index 000000000..caa9667c4 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu @@ -0,0 +1,145 @@ +/* + * 
Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "sort_unique_ids_for_hierarchy_func.h"
#include "sort_unique_indices_func.h"

// NOTE(review): the angle-bracket header names of this file were lost when the
// patch was extracted (only bare `#include` tokens remained). The header below
// is the one the visible code demonstrably needs; restore the remaining ones
// from the upstream file.
#include <cub/cub.cuh>  // cub::DeviceScan

#include "cuda_macros.hpp"
#include "error.hpp"
#include "logger.hpp"
#include "wholememory/communicator.hpp"
#include "wholememory/integer_utils.hpp"
#include "wholememory_ops/register.hpp"
#include "wholememory_ops/temp_memory_handle.hpp"

namespace wholememory_ops {

/**
 * @brief For every run of equal sorted indices, write the run number to the
 *        original positions of its members.
 *
 * Run `r` covers the sorted positions
 * `[unique_offset_ptr[r], unique_offset_ptr[r] + unique_count_ptr[r])`;
 * for each such position `i` the kernel stores `indice_map[sort_raw_indices[i]] = r`.
 * One thread handles one run at a time; runs at or beyond
 * `min(indice_count, num_unique)` are not processed.
 *
 * Launch requirements: 1-D grid and block; no shared memory.
 */
template <typename IndexT>
__global__ void SortUniqueIndiceMapKernel(IndexT* indice_map,
                                          size_t indice_count,
                                          const IndexT* sort_raw_indices,
                                          const int* unique_count_ptr,
                                          const IndexT* unique_offset_ptr,
                                          size_t num_unique)
{
  const size_t run_limit = (num_unique < indice_count) ? num_unique : indice_count;
  // 64-bit index and stride: an `int` loop variable overflows past INT_MAX.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t run = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; run < run_limit;
       run += stride) {
    const IndexT offset = unique_offset_ptr[run];
    const int count     = unique_count_ptr[run];
    for (IndexT i = offset; i < offset + count; i++) {
      indice_map[sort_raw_indices[i]] = static_cast<IndexT>(run);
    }
  }
}

/**
 * @brief Compute run offsets from run lengths and launch SortUniqueIndiceMapKernel.
 *
 * @param indice_map          output device array (element type IndexT), indexed by original position
 * @param indice_desc         description of the original indices (size and dtype are used)
 * @param sort_raw_indices    device array giving the original position of each sorted element
 * @param unique_count_ptr    device array with the length of every run
 * @param num_unique          number of runs; nothing is done when it is 0
 * @param p_thrust_allocator  provides the CUB scan workspace
 * @param p_env_fns           environment functions for the temporary offset buffer
 * @param stream              stream all work is enqueued on
 *
 * NOTE(review): the offset buffer and the CUB workspace are released while the
 * kernel may still be pending on @p stream; confirm the allocators tolerate that.
 */
template <typename IndexT>
void SortUniqueIndicesMapTempFunc(void* indice_map,
                                  wholememory_array_description_t indice_desc,
                                  const void* sort_raw_indices,
                                  const int* unique_count_ptr,
                                  size_t num_unique,
                                  wm_thrust_allocator* p_thrust_allocator,
                                  wholememory_env_func_t* p_env_fns,
                                  cudaStream_t stream)
{
  // No runs means nothing to map; it would also request a grid of zero blocks,
  // which is an invalid launch configuration.
  if (num_unique == 0) { return; }

  static constexpr int BLOCK_SIZE = 128;
  int block_count                 = wholememory::div_rounding_up_unsafe(num_unique, BLOCK_SIZE);

  temp_memory_handle dev_unique_offset_handle(p_env_fns);
  IndexT* unique_offset_ptr =
    static_cast<IndexT*>(dev_unique_offset_handle.device_malloc(num_unique, indice_desc.dtype));
  IndexT* indice_map_ptr             = static_cast<IndexT*>(indice_map);
  const IndexT* sort_raw_indices_ptr = static_cast<const IndexT*>(sort_raw_indices);

  size_t temp_storage_bytes = 0;
  // First call only queries the workspace size.
  WM_CUDA_CHECK(cub::DeviceScan::ExclusiveSum(
    nullptr, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream));
  auto* cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes);
  const cudaError_t scan_status = cub::DeviceScan::ExclusiveSum(
    cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream);
  cudaError_t launch_status = cudaSuccess;
  if (scan_status == cudaSuccess) {
    SortUniqueIndiceMapKernel<IndexT>
      <<<block_count, BLOCK_SIZE, 0, stream>>>(indice_map_ptr,
                                               indice_desc.size,
                                               sort_raw_indices_ptr,
                                               unique_count_ptr,
                                               unique_offset_ptr,
                                               num_unique);
    launch_status = cudaGetLastError();
  }
  // Hand the workspace back before reporting, so a failure does not strand it.
  p_thrust_allocator->deallocate(cub_temp_storage, temp_storage_bytes);
  WM_CUDA_CHECK(scan_status);
  WM_CUDA_CHECK(launch_status);
}

REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesMapTempFunc, SortUniqueIndicesMapTempFunc, SINT3264)

/**
 * @brief Deduplicate @p indices and build the map from each input position to
 *        the position of its value in the deduplicated output.
 *
 * @param indices                device indices
 * @param indice_desc            description of @p indices
 * @param output_indices_handle  receives the buffer of unique indices
 * @param output_indices_desc    output: 1-D description with the number of unique indices
 *                               (size 0 when the input is empty)
 * @param dev_indice_map_handle  receives a buffer of `indice_desc.size` map entries
 * @param p_thrust_allocator     scratch allocator forwarded to the helpers
 * @param p_env_fns              environment functions for temporary buffers
 * @param stream                 stream the work is enqueued on
 * @return WHOLEMEMORY_SUCCESS, or the error reported by sort_unique_indices_func
 */
wholememory_error_code_t sort_unique_ids_for_hierarchy_func(
  void* indices,
  wholememory_array_description_t indice_desc,
  temp_memory_handle* output_indices_handle,
  wholememory_array_description_t* output_indices_desc,
  temp_memory_handle* dev_indice_map_handle,
  wm_thrust_allocator* p_thrust_allocator,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream)
{
  if (indice_desc.size == 0) {
    *output_indices_desc = wholememory_create_array_desc(0, 0, indice_desc.dtype);
    return WHOLEMEMORY_SUCCESS;
  }
  int num_runs = 0;
  temp_memory_handle unique_count_handle(p_env_fns);
  temp_memory_handle dev_sort_raw_indices_handle(p_env_fns);
  void* dev_sort_raw_indices_ptr =
    dev_sort_raw_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype);
  // The status used to be discarded, so a failed sort went on to build a map
  // from unwritten buffers.
  const wholememory_error_code_t sort_status = sort_unique_indices_func(indices,
                                                                        indice_desc,
                                                                        dev_sort_raw_indices_ptr,
                                                                        &num_runs,
                                                                        output_indices_handle,
                                                                        &unique_count_handle,
                                                                        p_thrust_allocator,
                                                                        p_env_fns,
                                                                        stream);
  if (sort_status != WHOLEMEMORY_SUCCESS) { return sort_status; }
  // num_runs is filled by a cudaMemcpyAsync on `stream` inside
  // SortUniqueIndicesTempFunc; wait for it before the value is used on the host.
  WM_CUDA_CHECK(cudaStreamSynchronize(stream));

  *output_indices_desc = wholememory_create_array_desc(num_runs, 0, indice_desc.dtype);
  void* dev_indice_map_ptr =
    dev_indice_map_handle->device_malloc(indice_desc.size, indice_desc.dtype);
  WM_CUDA_CHECK(cudaGetLastError());
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      SortUniqueIndicesMapTempFunc,
                      dev_indice_map_ptr,
                      indice_desc,
                      dev_sort_raw_indices_ptr,
                      static_cast<int*>(unique_count_handle.pointer()),
                      num_runs,
                      p_thrust_allocator,
                      p_env_fns,
                      stream);
  } catch (...) {
    WHOLEMEMORY_FAIL_NOTHROW("map indices failed");
  }
  return WHOLEMEMORY_SUCCESS;
}

}  // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h new file mode 100644 index 000000000..8491e58f7 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#pragma once + +#include "wholememory_ops/temp_memory_handle.hpp" +#include +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + temp_memory_handle* output_indices_handle, + wholememory_array_description_t* output_indices_desc, + temp_memory_handle* dev_indice_map_handle, // indice_desc + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu new file mode 100644 index 000000000..a3d3fc647 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

#include "sort_indices_func.h"
#include "sort_unique_indices_func.h"

// NOTE(review): the angle-bracket header names of this file were lost when the
// patch was extracted (only bare `#include` tokens remained). The header below
// is the one the visible code demonstrably needs; restore the remaining ones
// from the upstream file.
#include <cub/cub.cuh>  // cub::DeviceRunLengthEncode

#include "cuda_macros.hpp"
#include "error.hpp"
#include "logger.hpp"
#include "wholememory_ops/register.hpp"

namespace wholememory_ops {

/**
 * @brief Sort the indices, then run-length encode the sorted sequence.
 *
 * @param indices                device indices (element type IndexT); storage_offset must be 0
 * @param indice_desc            description of @p indices; nothing is done when size is 0
 * @param sort_raw_indices       output buffer passed to sort_indices_func
 *                               (presumably the original position of every sorted element --
 *                               confirm in sort_indices_func.h)
 * @param num_runs               host output: number of distinct values
 * @param unique_indices_handle  receives a buffer of `indice_desc.size` elements whose first
 *                               `*num_runs` entries are the distinct values
 * @param unique_count_handle    receives a buffer of `indice_desc.size` ints whose first
 *                               `*num_runs` entries are the run lengths
 * @param p_thrust_allocator     provides the CUB workspace
 * @param p_env_fns              environment functions for temporary device buffers
 * @param stream                 stream the work is enqueued on; it is synchronized before
 *                               returning so that `*num_runs` is valid for the caller
 *
 * NOTE(review): the result of sort_indices_func is not inspected because its
 * declaration is not visible here; check it if it returns a status.
 */
template <typename IndexT>
void SortUniqueIndicesTempFunc(const void* indices,
                               wholememory_array_description_t indice_desc,
                               void* sort_raw_indices,
                               int* num_runs,
                               temp_memory_handle* unique_indices_handle,
                               temp_memory_handle* unique_count_handle,
                               wm_thrust_allocator* p_thrust_allocator,
                               wholememory_env_func_t* p_env_fns,
                               cudaStream_t stream)
{
  if (indice_desc.size == 0) return;
  wm_thrust_allocator& allocator = *p_thrust_allocator;
  WHOLEMEMORY_CHECK_NOTHROW(indice_desc.storage_offset == 0);

  temp_memory_handle sorted_indices_handle(p_env_fns);
  sorted_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype);
  IndexT* sorted_indices = static_cast<IndexT*>(sorted_indices_handle.pointer());

  sort_indices_func(
    indices, indice_desc, sorted_indices, sort_raw_indices, p_thrust_allocator, p_env_fns, stream);

  // Worst case every index is distinct, so both outputs get one slot per input.
  unique_indices_handle->device_malloc(indice_desc.size, indice_desc.dtype);
  unique_count_handle->device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT);
  IndexT* unique_indices = static_cast<IndexT*>(unique_indices_handle->pointer());
  int* unique_counts     = static_cast<int*>(unique_count_handle->pointer());

  temp_memory_handle number_runs_handle(p_env_fns);
  number_runs_handle.device_malloc(1, WHOLEMEMORY_DT_INT);
  int* number_runs = static_cast<int*>(number_runs_handle.pointer());

  size_t temp_storage_bytes = 0;
  // First call only queries the workspace size.
  WM_CUDA_CHECK(cub::DeviceRunLengthEncode::Encode(nullptr,
                                                   temp_storage_bytes,
                                                   sorted_indices,
                                                   unique_indices,
                                                   unique_counts,
                                                   number_runs,
                                                   indice_desc.size,
                                                   stream));
  auto* cub_temp_storage          = allocator.allocate(temp_storage_bytes);
  const cudaError_t encode_status = cub::DeviceRunLengthEncode::Encode(cub_temp_storage,
                                                                       temp_storage_bytes,
                                                                       sorted_indices,
                                                                       unique_indices,
                                                                       unique_counts,
                                                                       number_runs,
                                                                       indice_desc.size,
                                                                       stream);
  if (encode_status == cudaSuccess) {
    WM_CUDA_CHECK_NO_THROW(
      cudaMemcpyAsync(num_runs, number_runs, sizeof(int), cudaMemcpyDeviceToHost, stream));
  }
  // Wait for the stream: the caller reads *num_runs on the host right after this
  // returns, and the temporary buffers above are released when leaving scope.
  const cudaError_t sync_status = cudaStreamSynchronize(stream);
  // The workspace was previously never handed back to the allocator.
  allocator.deallocate(cub_temp_storage, temp_storage_bytes);
  WM_CUDA_CHECK(encode_status);
  WM_CUDA_CHECK(sync_status);
}

REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesTempFunc, SortUniqueIndicesTempFunc, SINT3264)

/**
 * @brief Type-dispatching, non-throwing entry point for SortUniqueIndicesTempFunc.
 *
 * Parameters are forwarded unchanged; see SortUniqueIndicesTempFunc.
 *
 * @return WHOLEMEMORY_SUCCESS on success, WHOLEMEMORY_CUDA_ERROR for a
 *         wholememory::cuda_error, WHOLEMEMORY_LOGIC_ERROR for a
 *         wholememory::logic_error, WHOLEMEMORY_UNKNOW_ERROR for anything else.
 */
wholememory_error_code_t sort_unique_indices_func(const void* indices,
                                                  wholememory_array_description_t indice_desc,
                                                  void* sort_raw_indices,
                                                  int* num_runs,
                                                  temp_memory_handle* unique_indices_handle,
                                                  temp_memory_handle* unique_count_handle,
                                                  wm_thrust_allocator* p_thrust_allocator,
                                                  wholememory_env_func_t* p_env_fns,
                                                  cudaStream_t stream)
{
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      SortUniqueIndicesTempFunc,
                      indices,
                      indice_desc,
                      sort_raw_indices,
                      num_runs,
                      unique_indices_handle,
                      unique_count_handle,
                      p_thrust_allocator,
                      p_env_fns,
                      stream);
  } catch (wholememory::cuda_error& wce) {
    WHOLEMEMORY_ERROR("sort_unique_indices_func CUDA LOGIC Error %s\n", wce.what());
    return WHOLEMEMORY_CUDA_ERROR;
  } catch (wholememory::logic_error& wle) {
    WHOLEMEMORY_ERROR("sort_unique_indices_func LOGIC Error %s\n", wle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_UNKNOW_ERROR;
  }
  return WHOLEMEMORY_SUCCESS;
}

}  // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h new file mode 100644 index 000000000..2ff697c90 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_indices_func(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu index 2e49bcd02..808ebe768 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -21,16 +21,110 @@ #include "logger.hpp" #include "wholememory/communicator.hpp" #include "wholememory/memory_handle.hpp" -#include "wholememory_ops/functions/bucket_ids_func.h" +#include "wholememory_ops/functions/bucket_ids_for_hierarchy_func.h" #include "wholememory_ops/functions/exchange_embeddings_nccl_func.h" -#include "wholememory_ops/functions/exchange_ids_nccl_func.h" #include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h" #include "wholememory_ops/gather_op_impl.h" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" namespace wholememory_ops { +static wholememory_error_code_t wholememory_cross_gather( + wholememory_handle_t wholememory_handle, + 
wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms) +{ + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + // bucket ids + std::vector host_bucket_id_count(cross_size, 0); + std::vector host_bucket_id_offset(cross_size); + std::vector host_recv_id_count(cross_size, 0); + std::vector host_recv_id_offset(cross_size); + bucket_local_ids_func(indices, + indice_desc, + host_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_local_comm, + wm_cross_comm, + p_thrust_allocator, + p_env_fns, + stream); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // exchange node count + wm_cross_comm->host_alltoall( + host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_offset[0] = 0; + for (int i = 1; i < cross_size; i++) + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1]; + wm_cross_comm->sync_stream(); + // exchange indices + int64_t total_recv_count = 0; + for (int i = 0; i < cross_size; i++) { + host_recv_id_offset[i] = total_recv_count; + total_recv_count += host_recv_id_count[i]; + } + temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns); + void* dev_recv_bucket_indices_ptr = + dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); + wm_cross_comm->alltoallv(indices, + dev_recv_bucket_indices_ptr, + reinterpret_cast(host_bucket_id_count.data()), + reinterpret_cast(host_bucket_id_offset.data()), + reinterpret_cast(host_recv_id_count.data()), + reinterpret_cast(host_recv_id_offset.data()), + indice_desc.dtype, + stream); + 
wm_cross_comm->sync_stream(stream); + // local gather + temp_memory_handle dev_local_gather_buffer_handle(p_env_fns); + void* dev_local_gather_buffer_ptr = dev_local_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t local_gather_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( + local_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + void* local_fake_ptr = nullptr; + size_t local_mem_offset, local_mem_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( + &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); + local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; + wholememory_gref_t local_fake_gref = + wholememory_create_continuous_global_reference(local_fake_ptr); + auto local_gather_indice_desc = + wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, + wholememory_desc, + dev_recv_bucket_indices_ptr, + local_gather_indice_desc, + dev_local_gather_buffer_ptr, + local_gather_buffer_desc, + stream, + gather_sms)); + // exchange embeddings + size_t output_embedding_size = + wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count.data(), + output, + output_embedding_size, + wm_cross_comm, + stream)); + return WHOLEMEMORY_SUCCESS; +} + wholememory_error_code_t wholememory_gather_hierarchy( wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, @@ -53,109 +147,171 @@ wholememory_error_code_t wholememory_gather_hierarchy( size_t embedding_size_per_rank; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - 
size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( embedding_size_per_rank % embedding_entry_size == 0, "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", embedding_size_per_rank, element_size, wholememory_desc.stride); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); + wholememory_comm_t wm_global_comm; + int world_size, world_rank; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_global_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_global_comm)); + + wholememory_comm_t wm_local_comm; + int local_size, local_rank; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); + // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( + // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); - int world_size; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); + wholememory_comm_t wm_cross_comm; + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); + // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( + // &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + 
WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); - temp_memory_handle host_rank_id_count(p_env_fns), host_recv_rank_id_count(p_env_fns); - int64_t* host_rank_id_count_ptr = - static_cast(host_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); - int64_t* host_recv_rank_id_count_ptr = - static_cast(host_recv_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); + temp_memory_handle dev_bucket_indices_handle(p_env_fns); + void* dev_bucket_indices_ptr = + dev_bucket_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + temp_memory_handle dev_bucket_ids_map_handle(p_env_fns); + void* dev_bucket_ids_map_ptr = + dev_bucket_ids_map_handle.device_malloc(indice_desc.size, indice_desc.dtype); - temp_memory_handle dev_recv_indice_buffer(p_env_fns); - temp_memory_handle dev_raw_indice(p_env_fns); - int64_t* dev_raw_indice_ptr = - static_cast(dev_raw_indice.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); + std::vector host_bucket_id_count(local_size, 0); + std::vector host_bucket_id_offset(local_size); + std::vector host_recv_id_count(local_size, 0); + std::vector host_recv_id_offset(local_size); + // bucket indices + WHOLEMEMORY_RETURN_ON_FAIL( + bucket_and_reorder_ids_for_hierarchy_func(indices, + indice_desc, + dev_bucket_indices_ptr, + dev_bucket_ids_map_ptr, + host_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_global_comm, + wm_local_comm, + &thrust_allocator, + p_env_fns, + stream)); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // exchange node count + wm_local_comm->host_alltoall( + host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_offset[0] = 0; + for (int i = 1; i < local_size; i++) + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1]; + wm_local_comm->sync_stream(); + // exchange indices int64_t total_recv_count = 0; - WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, - indice_desc, - 
host_recv_rank_id_count_ptr, - host_rank_id_count_ptr, - &dev_recv_indice_buffer, - dev_raw_indice_ptr, - embedding_entry_count_per_rank, - wm_comm, - &thrust_allocator, - p_env_fns, - stream)); - - // Local Gather - for (int i = 0; i < world_size; i++) { - total_recv_count += host_recv_rank_id_count_ptr[i]; + for (int i = 0; i < local_size; i++) { + host_recv_id_offset[i] = total_recv_count; + total_recv_count += host_recv_id_count[i]; } - size_t local_mem_offset, local_mem_size; - temp_memory_handle dev_local_gather_buffer(p_env_fns); - temp_memory_handle dev_embedding_recv_buffer(p_env_fns); - void* dev_local_gather_buffer_ptr = dev_local_gather_buffer.device_malloc( - wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); - void* dev_embedding_recv_buffer_ptr = dev_embedding_recv_buffer.device_malloc( - wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); - void* local_fake_ptr = nullptr; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( - &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); - local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; - wholememory_gref_t local_fake_gref = - wholememory_create_continuous_global_reference(local_fake_ptr); - int64_t local_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; - wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( - local_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); - auto dev_recv_indice_desc = + temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns); + void* dev_recv_bucket_indices_ptr = + dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); + auto recv_bucket_indices_desc = wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); - WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, - wholememory_desc, - dev_recv_indice_buffer.pointer(), - dev_recv_indice_desc, - dev_local_gather_buffer_ptr, - 
local_gather_buffer_desc, + wm_local_comm->alltoallv(dev_bucket_indices_ptr, + dev_recv_bucket_indices_ptr, + reinterpret_cast(host_bucket_id_count.data()), + reinterpret_cast(host_bucket_id_offset.data()), + reinterpret_cast(host_recv_id_count.data()), + reinterpret_cast(host_recv_id_offset.data()), + indice_desc.dtype, + stream); + wm_local_comm->sync_stream(stream); + WM_CUDA_CHECK(cudaGetLastError()); + // sort unique recv indices + temp_memory_handle sort_unique_indices_handle(p_env_fns); + wholememory_array_description_t sort_unique_indice_desc; + temp_memory_handle dev_sort_unique_ids_map_handle(p_env_fns); + sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, + recv_bucket_indices_desc, + &sort_unique_indices_handle, + &sort_unique_indice_desc, + &dev_sort_unique_ids_map_handle, + &thrust_allocator, + p_env_fns, + stream); + // cross gather + temp_memory_handle dev_cross_gather_buffer_handle(p_env_fns); + void* dev_cross_gather_buffer_ptr = dev_cross_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * sort_unique_indice_desc.size, output_desc.dtype); + int64_t cross_gather_buffer_size[2] = {sort_unique_indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t cross_gather_buffer_desc = wholememory_create_matrix_desc( + cross_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_cross_gather(wholememory_handle, + wholememory_desc, + sort_unique_indices_handle.pointer(), + sort_unique_indice_desc, + dev_cross_gather_buffer_ptr, + cross_gather_buffer_desc, + embedding_entry_count_per_rank, + wm_local_comm, + wm_cross_comm, + &thrust_allocator, + p_env_fns, + stream, + gather_sms); + // sort-unique reorder + temp_memory_handle dev_embedding_map_buffer_handle(p_env_fns); + void* dev_embedding_map_buffer_ptr = dev_embedding_map_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t embedding_map_buffer_size[2] = {total_recv_count, 
wholememory_desc.sizes[1]}; + wholememory_matrix_description_t embedding_map_buffer_desc = wholememory_create_matrix_desc( + embedding_map_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_gref_t cross_gather_fake_gref = + wholememory_create_continuous_global_reference(dev_cross_gather_buffer_ptr); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(cross_gather_fake_gref, + cross_gather_buffer_desc, + dev_sort_unique_ids_map_handle.pointer(), + recv_bucket_indices_desc, + dev_embedding_map_buffer_ptr, + embedding_map_buffer_desc, stream, gather_sms)); - // AllToAllV for embeddings - size_t embedding_size = + // exchange embeddings + size_t output_embedding_size = wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); - WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, - host_recv_rank_id_count_ptr, - host_rank_id_count_ptr, - dev_embedding_recv_buffer_ptr, - embedding_size, - wm_comm, + temp_memory_handle dev_recv_embedding_buffer_handle(p_env_fns); + void* dev_recv_embedding_buffer_ptr = dev_recv_embedding_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_embedding_map_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count.data(), + dev_recv_embedding_buffer_ptr, + output_embedding_size, + wm_local_comm, stream)); - // Local reorder - int64_t total_need_indice_count = 0; - for (int i = 0; i < world_size; i++) { - total_need_indice_count += host_rank_id_count_ptr[i]; - } - wholememory_gref_t output_gref = wholememory_create_continuous_global_reference(output); - wholememory_matrix_description_t local_recv_buffer_desc = - wholememory_create_matrix_desc(output_desc.sizes, output_desc.sizes[1], 0, output_desc.dtype); - local_recv_buffer_desc.sizes[0] = total_need_indice_count; - auto raw_indice_desc = - wholememory_create_array_desc(total_need_indice_count, 0, 
WHOLEMEMORY_DT_INT64); - WHOLEMEMORY_RETURN_ON_FAIL(scatter_func(dev_embedding_recv_buffer_ptr, - local_recv_buffer_desc, - dev_raw_indice_ptr, - raw_indice_desc, - output_gref, - output_desc, - stream)); + // bucket reorder + wholememory_gref_t recv_embedding_buffer_fake_gref = + wholememory_create_continuous_global_reference(dev_recv_embedding_buffer_ptr); + int64_t recv_embedding_buffer_size[2] = {indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t recv_embedding_buffer_desc = wholememory_create_matrix_desc( + recv_embedding_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(recv_embedding_buffer_fake_gref, + recv_embedding_buffer_desc, + dev_bucket_ids_map_ptr, + indice_desc, + output, + output_desc, + stream, + gather_sms)); WM_CUDA_CHECK(cudaGetLastError()); - // WM_CUDA_CHECK(cudaStreamSynchronize(stream)); } catch (wholememory::cuda_error& wce) { WHOLEMEMORY_ERROR("CUDA logic Error %s\n", wce.what()); return WHOLEMEMORY_CUDA_ERROR;