Skip to content

Commit

Permalink
implement hierarchy gather function
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuofan1123 committed Aug 19, 2024
1 parent 01fcd1a commit d7ace2d
Show file tree
Hide file tree
Showing 7 changed files with 1,009 additions and 83 deletions.
386 changes: 386 additions & 0 deletions cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <wholememory/tensor_description.h>
#include <wholememory/wholememory.h>
#include <wholememory_ops/thrust_allocator.hpp>

#include "wholememory_ops/temp_memory_handle.hpp"

namespace wholememory_ops {

// Buckets `indices` by destination rank inside the hierarchical (global =
// local-node x cross-node) communicator layout and reorders them so that ids
// destined for the same rank are contiguous.
//
// indices / indice_desc          : input id array on device and its description.
// dev_bucket_indices             : device output, ids grouped by bucket.
// dev_indice_map                 : device output, mapping from original position
//                                  to reordered position (presumably used to
//                                  scatter gathered results back — TODO confirm
//                                  against bucket_ids_for_hierarchy_func.cu).
// host_bucket_id_count           : host output, per-bucket id counts.
// embedding_entry_count_per_rank : entries owned by each rank; assumed to define
//                                  the id -> rank mapping by division — confirm.
// wm_global_comm / wm_local_comm : global and intra-node communicators.
// p_thrust_allocator / p_env_fns : temporary-memory providers.
// stream                         : CUDA stream all device work is issued on.
wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
  void* indices,
  wholememory_array_description_t indice_desc,
  void* dev_bucket_indices,
  void* dev_indice_map,
  int64_t* host_bucket_id_count,
  size_t embedding_entry_count_per_rank,
  wholememory_comm_t wm_global_comm,
  wholememory_comm_t wm_local_comm,
  wm_thrust_allocator* p_thrust_allocator,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream);

// Counts how many of the (already node-local) ids fall into each bucket of the
// local/cross communicator pair without reordering them.
// NOTE(review): implementation not shown here — parameter semantics mirror the
// function above; verify against bucket_ids_for_hierarchy_func.cu.
wholememory_error_code_t bucket_local_ids_func(void* indices,
                                               wholememory_array_description_t indice_desc,
                                               int64_t* host_bucket_id_count,
                                               size_t embedding_entry_count_per_rank,
                                               wholememory_comm_t wm_local_comm,
                                               wholememory_comm_t wm_cross_comm,
                                               wm_thrust_allocator* p_thrust_allocator,
                                               wholememory_env_func_t* p_env_fns,
                                               cudaStream_t stream);

}  // namespace wholememory_ops
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sort_unique_ids_for_hierarchy_func.h"
#include "sort_unique_indices_func.h"

#include <cassert>
#include <cstdint>

#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <thrust/sequence.h>
#include <thrust/unique.h>

#include <wholememory/wholememory.h>

#include "cuda_macros.hpp"
#include "error.hpp"
#include "logger.hpp"
#include "wholememory/communicator.hpp"
#include "wholememory/integer_utils.hpp"
#include "wholememory_ops/register.hpp"
#include "wholememory_ops/temp_memory_handle.hpp"
#include <wholememory_ops/thrust_allocator.hpp>

namespace wholememory_ops {

// Scatters, for every occurrence of a unique id, the id's compacted position
// back into `indice_map` at the occurrence's original location.
//
// One thread handles one unique id: `unique_offset_ptr[idx]` is the offset of
// the id's first occurrence in the sorted order, `unique_count_ptr[idx]` the
// run length, and `sort_raw_indices[i]` the original position of the i-th
// sorted element.
//
// Fix vs. original: the loop was bounded by `indice_count` with an inner
// `break` at `num_unique`; since num_unique <= indice_count the effective
// bound was always `num_unique`, so loop over it directly. The index is also
// widened to size_t so counts beyond INT_MAX cannot overflow.
// `indice_count` is kept (unused) to preserve the launch signature.
template <typename IndexT>
__global__ void SortUniqueIndiceMapKernel(IndexT* indice_map,
                                          size_t indice_count,
                                          const IndexT* sort_raw_indices,
                                          const int* unique_count_ptr,
                                          const IndexT* unique_offset_ptr,
                                          size_t num_unique)
{
  size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t idx = threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x; idx < num_unique;
       idx += stride) {
    IndexT offset = unique_offset_ptr[idx];
    int count     = unique_count_ptr[idx];
    for (IndexT i = offset; i < offset + count; i++) {
      indice_map[sort_raw_indices[i]] = static_cast<IndexT>(idx);
    }
  }
}

// Builds the original-position -> unique-position map.
//
// Computes an exclusive prefix sum over the per-unique-id run lengths
// (`unique_count_ptr`) to obtain each id's first-occurrence offset in the
// sorted order, then launches SortUniqueIndiceMapKernel to scatter the
// compacted positions back through `sort_raw_indices`.
//
// Preconditions (from callers in this file): num_unique >= 1 and
// indice_desc.size >= 1, so block_count is never zero.
template <typename IndexT>
void SortUniqueIndicesMapTempFunc(void* indice_map,
                                  wholememory_array_description_t indice_desc,
                                  const void* sort_raw_indices,
                                  const int* unique_count_ptr,
                                  size_t num_unique,
                                  wm_thrust_allocator* p_thrust_allocator,
                                  wholememory_env_func_t* p_env_fns,
                                  cudaStream_t stream)
{
  static constexpr int BLOCK_SIZE = 128;
  int block_count = wholememory::div_rounding_up_unsafe(num_unique, BLOCK_SIZE);

  temp_memory_handle dev_unique_offset_handle(p_env_fns);
  IndexT* unique_offset_ptr =
    static_cast<IndexT*>(dev_unique_offset_handle.device_malloc(num_unique, indice_desc.dtype));
  IndexT* indice_map_ptr              = static_cast<IndexT*>(indice_map);
  const IndexT* sort_raw_indices_ptr  = static_cast<const IndexT*>(sort_raw_indices);

  // Standard CUB two-phase pattern: first call only queries the required
  // temp-storage size, second call performs the scan.
  void* cub_temp_storage    = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(
    cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream);
  cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes);
  cub::DeviceScan::ExclusiveSum(
    cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream);
  SortUniqueIndiceMapKernel<<<block_count, BLOCK_SIZE, 0, stream>>>(indice_map_ptr,
                                                                    indice_desc.size,
                                                                    sort_raw_indices_ptr,
                                                                    unique_count_ptr,
                                                                    unique_offset_ptr,
                                                                    num_unique);
  // Kernel launches do not return errors: surface a bad launch configuration
  // here instead of at the next unrelated synchronizing call.
  WM_CUDA_CHECK(cudaGetLastError());
  p_thrust_allocator->deallocate(reinterpret_cast<char*>(cub_temp_storage), temp_storage_bytes);
}

REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesMapTempFunc, SortUniqueIndicesMapTempFunc, SINT3264)

// Sorts `indices`, compacts them to their unique values, and produces a map
// from each original position to the position of its value in the unique
// (output) array.
//
// output_indices_handle / output_indices_desc : unique ids and their
//   description (sized to the number of unique runs).
// dev_indice_map_handle : device map of length indice_desc.size; entry i is
//   the unique-array position of indices[i].
//
// Returns WHOLEMEMORY_SUCCESS, or fails via WHOLEMEMORY_FAIL_NOTHROW on a
// dispatch error. For empty input only the output description is written; no
// device buffers are allocated.
wholememory_error_code_t sort_unique_ids_for_hierarchy_func(
  void* indices,
  wholememory_array_description_t indice_desc,
  temp_memory_handle* output_indices_handle,
  wholememory_array_description_t* output_indices_desc,
  temp_memory_handle* dev_indice_map_handle,
  wm_thrust_allocator* p_thrust_allocator,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream)
{
  if (indice_desc.size == 0) {
    *output_indices_desc = wholememory_create_array_desc(0, 0, indice_desc.dtype);
    return WHOLEMEMORY_SUCCESS;
  }
  int num_runs = 0;
  temp_memory_handle unique_count_handle(p_env_fns);
  temp_memory_handle dev_sort_raw_indices_handle(p_env_fns);
  void* dev_sort_raw_indices_ptr =
    dev_sort_raw_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype);
  sort_unique_indices_func(indices,
                           indice_desc,
                           dev_sort_raw_indices_ptr,
                           &num_runs,
                           output_indices_handle,
                           &unique_count_handle,
                           p_thrust_allocator,
                           p_env_fns,
                           stream);
  // FIX: num_runs is written back by a cudaMemcpyAsync on `stream` inside
  // sort_unique_indices_func; it is only valid on the host once the stream
  // has drained. Synchronize before using it to size the output description
  // and the map-kernel launch below.
  WM_CUDA_CHECK(cudaStreamSynchronize(stream));
  *output_indices_desc = wholememory_create_array_desc(num_runs, 0, indice_desc.dtype);
  void* dev_indice_map_ptr =
    dev_indice_map_handle->device_malloc(indice_desc.size, indice_desc.dtype);
  WM_CUDA_CHECK(cudaGetLastError());
  try {
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      SortUniqueIndicesMapTempFunc,
                      dev_indice_map_ptr,
                      indice_desc,
                      dev_sort_raw_indices_ptr,
                      static_cast<int*>(unique_count_handle.pointer()),
                      num_runs,
                      p_thrust_allocator,
                      p_env_fns,
                      stream);
  } catch (...) {
    WHOLEMEMORY_FAIL_NOTHROW("map indices failed");
  }
  return WHOLEMEMORY_SUCCESS;
}

} // namespace wholememory_ops
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "wholememory_ops/temp_memory_handle.hpp"
#include <wholememory/tensor_description.h>
#include <wholememory/wholememory.h>
#include <wholememory_ops/thrust_allocator.hpp>

namespace wholememory_ops {

// Sorts `indices`, compacts them to unique values, and builds a map from each
// original position to its value's position in the unique output array.
//
// indices / indice_desc   : input id array on device and its description.
// output_indices_handle   : receives the unique ids (device allocation).
// output_indices_desc     : receives the description of the unique array.
// dev_indice_map_handle   : receives a device array of indice_desc.size
//                           entries mapping original -> unique positions.
// p_thrust_allocator / p_env_fns : temporary-memory providers.
// stream                  : CUDA stream all device work is issued on.
wholememory_error_code_t sort_unique_ids_for_hierarchy_func(
  void* indices,
  wholememory_array_description_t indice_desc,
  temp_memory_handle* output_indices_handle,
  wholememory_array_description_t* output_indices_desc,
  temp_memory_handle* dev_indice_map_handle,  // indice_desc
  wm_thrust_allocator* p_thrust_allocator,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream);

}  // namespace wholememory_ops
118 changes: 118 additions & 0 deletions cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "sort_indices_func.h"
#include "sort_unique_indices_func.h"

#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <thrust/sequence.h>

#include "cuda_macros.hpp"
#include "error.hpp"
#include "logger.hpp"
#include "wholememory_ops/register.hpp"

namespace wholememory_ops {

// Sorts `indices` and run-length-encodes the sorted sequence, producing the
// unique ids, their per-id occurrence counts, and the run count.
//
// sort_raw_indices      : device output, the permutation applied by the sort
//                         (position i holds the original index of sorted[i]).
// num_runs              : HOST output, the number of unique ids; valid on
//                         return (this function synchronizes the stream).
// unique_indices_handle : receives the unique ids (allocated to size, an
//                         upper bound on the run count).
// unique_count_handle   : receives the int run lengths (same upper bound).
//
// No-op for empty input; `num_runs` is then left untouched (callers
// initialize it to 0).
template <typename IndexT>
void SortUniqueIndicesTempFunc(const void* indices,
                               wholememory_array_description_t indice_desc,
                               void* sort_raw_indices,
                               int* num_runs,
                               temp_memory_handle* unique_indices_handle,
                               temp_memory_handle* unique_count_handle,
                               wm_thrust_allocator* p_thrust_allocator,
                               wholememory_env_func_t* p_env_fns,
                               cudaStream_t stream)
{
  if (indice_desc.size == 0) return;
  wm_thrust_allocator& allocator = *p_thrust_allocator;
  WHOLEMEMORY_CHECK_NOTHROW(indice_desc.storage_offset == 0);
  temp_memory_handle sorted_indices_handle(p_env_fns);
  sorted_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype);
  IndexT* sorted_indices = static_cast<IndexT*>(sorted_indices_handle.pointer());

  sort_indices_func(
    indices, indice_desc, sorted_indices, sort_raw_indices, p_thrust_allocator, p_env_fns, stream);

  unique_indices_handle->device_malloc(indice_desc.size, indice_desc.dtype);
  unique_count_handle->device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT);
  IndexT* unique_indices = static_cast<IndexT*>(unique_indices_handle->pointer());
  int* unique_counts     = static_cast<int*>(unique_count_handle->pointer());
  temp_memory_handle number_runs_handle(p_env_fns);
  number_runs_handle.device_malloc(1, WHOLEMEMORY_DT_INT);
  int* number_runs = static_cast<int*>(number_runs_handle.pointer());
  // Standard CUB two-phase pattern: size query, then the actual encode.
  void* cub_temp_storage    = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(cub_temp_storage,
                                     temp_storage_bytes,
                                     sorted_indices,
                                     unique_indices,
                                     unique_counts,
                                     number_runs,
                                     indice_desc.size,
                                     stream);
  cub_temp_storage = allocator.allocate(temp_storage_bytes);
  cub::DeviceRunLengthEncode::Encode(cub_temp_storage,
                                     temp_storage_bytes,
                                     sorted_indices,
                                     unique_indices,
                                     unique_counts,
                                     number_runs,
                                     indice_desc.size,
                                     stream);
  WM_CUDA_CHECK_NO_THROW(
    cudaMemcpyAsync(num_runs, number_runs, sizeof(int), cudaMemcpyDeviceToHost, stream));
  // FIX: `num_runs` points at (possibly pageable) host memory and callers read
  // it immediately after this function returns; the async copy is only
  // guaranteed complete after the stream drains.
  WM_CUDA_CHECK_NO_THROW(cudaStreamSynchronize(stream));
  // FIX: return the CUB scratch buffer to the allocator (the sibling
  // SortUniqueIndicesMapTempFunc already does this) instead of holding it
  // until the allocator is destroyed.
  allocator.deallocate(reinterpret_cast<char*>(cub_temp_storage), temp_storage_bytes);
}

REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesTempFunc, SortUniqueIndicesTempFunc, SINT3264)

// Type-dispatching host wrapper around SortUniqueIndicesTempFunc: sorts
// `indices`, run-length-encodes the sorted sequence, and reports the unique
// ids, their counts, and the run count (see the template above for parameter
// details).
//
// Translates exceptions thrown by the dispatched implementation into
// wholememory error codes; never throws itself.
wholememory_error_code_t sort_unique_indices_func(const void* indices,
                                                  wholememory_array_description_t indice_desc,
                                                  void* sort_raw_indices,
                                                  int* num_runs,
                                                  temp_memory_handle* unique_indices_handle,
                                                  temp_memory_handle* unique_count_handle,
                                                  wm_thrust_allocator* p_thrust_allocator,
                                                  wholememory_env_func_t* p_env_fns,
                                                  cudaStream_t stream)
{
  try {
    // Dispatch on the runtime dtype (int32/int64) to the typed implementation.
    DISPATCH_ONE_TYPE(indice_desc.dtype,
                      SortUniqueIndicesTempFunc,
                      indices,
                      indice_desc,
                      sort_raw_indices,
                      num_runs,
                      unique_indices_handle,
                      unique_count_handle,
                      p_thrust_allocator,
                      p_env_fns,
                      stream);
  } catch (wholememory::cuda_error& wce) {
    WHOLEMEMORY_ERROR("sort_unique_indices_func CUDA LOGIC Error %s\n", wce.what());
    return WHOLEMEMORY_CUDA_ERROR;
  } catch (wholememory::logic_error& wle) {
    WHOLEMEMORY_ERROR("sort_unique_indices_func LOGIC Error %s\n", wle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_UNKNOW_ERROR;
  }
  return WHOLEMEMORY_SUCCESS;
}

} // namespace wholememory_ops
37 changes: 37 additions & 0 deletions cpp/src/wholememory_ops/functions/sort_unique_indices_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <wholememory/tensor_description.h>
#include <wholememory/wholememory.h>

#include <wholememory_ops/temp_memory_handle.hpp>
#include <wholememory_ops/thrust_allocator.hpp>

namespace wholememory_ops {

// Sorts `indices` and run-length-encodes the sorted sequence.
//
// indices / indice_desc  : input id array on device and its description.
// sort_raw_indices       : device output, the sort permutation (original
//                          position of each sorted element).
// num_runs               : host output, number of unique ids.
// unique_indices_handle  : receives the unique ids (device allocation).
// unique_count_handle    : receives the int occurrence count per unique id.
// p_thrust_allocator / p_env_fns : temporary-memory providers.
// stream                 : CUDA stream all device work is issued on.
wholememory_error_code_t sort_unique_indices_func(const void* indices,
                                                  wholememory_array_description_t indice_desc,
                                                  void* sort_raw_indices,
                                                  int* num_runs,
                                                  temp_memory_handle* unique_indices_handle,
                                                  temp_memory_handle* unique_count_handle,
                                                  wm_thrust_allocator* p_thrust_allocator,
                                                  wholememory_env_func_t* p_env_fns,
                                                  cudaStream_t stream);

}  // namespace wholememory_ops
Loading

0 comments on commit d7ace2d

Please sign in to comment.