Skip to content

Commit

Permalink
Add distinct key inner join (rapidsai#14990)
Browse files Browse the repository at this point in the history
Contributes to rapidsai#14948

This PR adds a public `cudf::distinct_hash_join` class that provides a fast code path for joins with distinct keys.

Only distinct inner join is tackled in the current PR.

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Bradley Dice (https://github.com/bdice)
  - Lawrence Mitchell (https://github.com/wence-)
  - David Wendt (https://github.com/davidwendt)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#14990
  • Loading branch information
PointKernel authored Feb 23, 2024
1 parent 8e87335 commit 71c9909
Show file tree
Hide file tree
Showing 9 changed files with 999 additions and 3 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ add_library(
src/jit/util.cpp
src/join/conditional_join.cu
src/join/cross_join.cu
src/join/distinct_hash_join.cu
src/join/hash_join.cu
src/join/join.cu
src/join/join_utils.cu
Expand Down
2 changes: 1 addition & 1 deletion cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ ConfigureNVBench(
# ##################################################################################################
# * join benchmark --------------------------------------------------------------------------------
ConfigureBench(JOIN_BENCH join/left_join.cu join/conditional_join.cu)
ConfigureNVBench(JOIN_NVBENCH join/join.cu join/mixed_join.cu)
ConfigureNVBench(JOIN_NVBENCH join/join.cu join/mixed_join.cu join/distinct_join.cu)

# ##################################################################################################
# * iterator benchmark ----------------------------------------------------------------------------
Expand Down
77 changes: 77 additions & 0 deletions cpp/benchmarks/join/distinct_join.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "join_common.hpp"

template <typename key_type, typename payload_type, bool Nullable>
void distinct_inner_join(nvbench::state& state,
nvbench::type_list<key_type, payload_type, nvbench::enum_type<Nullable>>)
{
skip_helper(state);

auto join = [](cudf::table_view const& left_input,
cudf::table_view const& right_input,
cudf::null_equality compare_nulls,
rmm::cuda_stream_view stream) {
auto const has_nulls = cudf::has_nested_nulls(left_input) || cudf::has_nested_nulls(right_input)
? cudf::nullable_join::YES
: cudf::nullable_join::NO;
auto hj_obj = cudf::distinct_hash_join<cudf::has_nested::NO>{
left_input, right_input, has_nulls, compare_nulls, stream};
return hj_obj.inner_join(stream);
};

BM_join<key_type, payload_type, Nullable>(state, join);
}

// inner join -----------------------------------------------------------------------
NVBENCH_BENCH_TYPES(distinct_inner_join,
NVBENCH_TYPE_AXES(nvbench::type_list<nvbench::int32_t>,
nvbench::type_list<nvbench::int32_t>,
nvbench::enum_type_list<false>))
.set_name("distinct_inner_join_32bit")
.set_type_axes_names({"Key Type", "Payload Type", "Nullable"})
.add_int64_axis("Build Table Size", {100'000, 10'000'000, 80'000'000, 100'000'000})
.add_int64_axis("Probe Table Size",
{100'000, 400'000, 10'000'000, 40'000'000, 100'000'000, 240'000'000});

NVBENCH_BENCH_TYPES(distinct_inner_join,
NVBENCH_TYPE_AXES(nvbench::type_list<nvbench::int64_t>,
nvbench::type_list<nvbench::int64_t>,
nvbench::enum_type_list<false>))
.set_name("distinct_inner_join_64bit")
.set_type_axes_names({"Key Type", "Payload Type", "Nullable"})
.add_int64_axis("Build Table Size", {40'000'000, 50'000'000})
.add_int64_axis("Probe Table Size", {50'000'000, 120'000'000});

NVBENCH_BENCH_TYPES(distinct_inner_join,
NVBENCH_TYPE_AXES(nvbench::type_list<nvbench::int32_t>,
nvbench::type_list<nvbench::int32_t>,
nvbench::enum_type_list<true>))
.set_name("distinct_inner_join_32bit_nulls")
.set_type_axes_names({"Key Type", "Payload Type", "Nullable"})
.add_int64_axis("Build Table Size", {100'000, 10'000'000, 80'000'000, 100'000'000})
.add_int64_axis("Probe Table Size",
{100'000, 400'000, 10'000'000, 40'000'000, 100'000'000, 240'000'000});

NVBENCH_BENCH_TYPES(distinct_inner_join,
NVBENCH_TYPE_AXES(nvbench::type_list<nvbench::int64_t>,
nvbench::type_list<nvbench::int64_t>,
nvbench::enum_type_list<true>))
.set_name("distinct_inner_join_64bit_nulls")
.set_type_axes_names({"Key Type", "Payload Type", "Nullable"})
.add_int64_axis("Build Table Size", {40'000'000, 50'000'000})
.add_int64_axis("Probe Table Size", {50'000'000, 120'000'000});
3 changes: 3 additions & 0 deletions cpp/include/cudf/detail/cuco_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

namespace cudf::detail {

/// Default load factor for cuco data structures
static double constexpr CUCO_DESIRED_LOAD_FACTOR = 0.5;

/**
* @brief Stream-ordered allocator adaptor used for cuco data structures
*
Expand Down
153 changes: 153 additions & 0 deletions cpp/include/cudf/detail/distinct_hash_join.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/hashing/detail/helper_functions.cuh>
#include <cudf/table/experimental/row_operators.cuh>
#include <cudf/types.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <cuco/static_set.cuh>

#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>

namespace cudf::detail {

using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;

/**
* @brief An comparator adapter wrapping both self comparator and two table comparator
*/
template <typename Equal>
struct comparator_adapter {
comparator_adapter(Equal const& d_equal) : _d_equal{d_equal} {}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const&,
cuco::pair<hash_value_type, lhs_index_type> const&) const noexcept
{
// All build table keys are distinct thus `false` no matter what
return false;
}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const& lhs,
cuco::pair<hash_value_type, rhs_index_type> const& rhs) const noexcept
{
if (lhs.first != rhs.first) { return false; }
return _d_equal(lhs.second, rhs.second);
}

private:
Equal _d_equal;
};

template <typename Hasher>
struct hasher_adapter {
hasher_adapter(Hasher const& d_hasher = {}) : _d_hasher{d_hasher} {}

template <typename T>
__device__ constexpr auto operator()(cuco::pair<hash_value_type, T> const& key) const noexcept
{
return _d_hasher(key.first);
}

private:
Hasher _d_hasher;
};

/**
* @brief Distinct hash join that builds hash table in creation and probes results in subsequent
* `*_join` member functions.
*
* @tparam HasNested Flag indicating whether there are nested columns in build/probe table
*/
template <cudf::has_nested HasNested>
struct distinct_hash_join {
private:
/// Row equality type for nested columns
using nested_row_equal = cudf::experimental::row::equality::strong_index_comparator_adapter<
cudf::experimental::row::equality::device_row_comparator<true, cudf::nullate::DYNAMIC>>;
/// Row equality type for flat columns
using flat_row_equal = cudf::experimental::row::equality::strong_index_comparator_adapter<
cudf::experimental::row::equality::device_row_comparator<false, cudf::nullate::DYNAMIC>>;

/// Device row equal type
using d_equal_type =
std::conditional_t<HasNested == cudf::has_nested::YES, nested_row_equal, flat_row_equal>;
using hasher = hasher_adapter<thrust::identity<hash_value_type>>;
using probing_scheme_type = cuco::linear_probing<1, hasher>;
using cuco_storage_type = cuco::storage<1>;

/// Hash table type
using hash_table_type = cuco::static_set<cuco::pair<hash_value_type, lhs_index_type>,
cuco::extent<size_type>,
cuda::thread_scope_device,
comparator_adapter<d_equal_type>,
probing_scheme_type,
cudf::detail::cuco_allocator,
cuco_storage_type>;

bool _has_nulls; ///< true if nulls are present in either build table or probe table
cudf::null_equality _nulls_equal; ///< whether to consider nulls as equal
cudf::table_view _build; ///< input table to build the hash map
cudf::table_view _probe; ///< input table to probe the hash map
std::shared_ptr<cudf::experimental::row::equality::preprocessed_table>
_preprocessed_build; ///< input table preprocssed for row operators
std::shared_ptr<cudf::experimental::row::equality::preprocessed_table>
_preprocessed_probe; ///< input table preprocssed for row operators
hash_table_type _hash_table; ///< hash table built on `_build`

public:
distinct_hash_join() = delete;
~distinct_hash_join() = default;
distinct_hash_join(distinct_hash_join const&) = delete;
distinct_hash_join(distinct_hash_join&&) = delete;
distinct_hash_join& operator=(distinct_hash_join const&) = delete;
distinct_hash_join& operator=(distinct_hash_join&&) = delete;

/**
* @brief Constructor that internally builds the hash table based on the given `build` table.
*
* @throw cudf::logic_error if the number of columns in `build` table is 0.
*
* @param build The build table, from which the hash table is built
* @param probe The probe table
* @param has_nulls Flag to indicate if any nulls exist in the `build` table or
* any `probe` table that will be used later for join.
* @param compare_nulls Controls whether null join-key values should match or not.
* @param stream CUDA stream used for device memory operations and kernel launches.
*/
distinct_hash_join(cudf::table_view const& build,
cudf::table_view const& probe,
bool has_nulls,
cudf::null_equality compare_nulls,
rmm::cuda_stream_view stream);

/**
* @copydoc cudf::distinct_hash_join::inner_join
*/
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
std::unique_ptr<rmm::device_uvector<size_type>>>
inner_join(rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const;
};
} // namespace cudf::detail
70 changes: 69 additions & 1 deletion cpp/include/cudf/join.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,13 @@

namespace cudf {

/**
* @brief Enum to indicate whether the distinct join table has nested columns or not
*
* @ingroup column_join
*/
enum class has_nested : bool { YES, NO };

// forward declaration
namespace hashing::detail {
template <typename T>
Expand All @@ -41,6 +48,9 @@ class MurmurHash3_x86_32;
namespace detail {
template <typename T>
class hash_join;

template <cudf::has_nested HasNested>
class distinct_hash_join;
} // namespace detail

/**
Expand Down Expand Up @@ -438,6 +448,64 @@ class hash_join {
const std::unique_ptr<impl_type const> _impl;
};

/**
* @brief Distinct hash join that builds hash table in creation and probes results in subsequent
* `*_join` member functions
*
* @note Behavior is undefined if the build table contains duplicates.
* @note All NaNs are considered as equal
*
* @tparam HasNested Flag indicating whether there are nested columns in build/probe table
*/
// TODO: `HasNested` to be removed via dispatching
template <cudf::has_nested HasNested>
class distinct_hash_join {
public:
distinct_hash_join() = delete;
~distinct_hash_join();
distinct_hash_join(distinct_hash_join const&) = delete;
distinct_hash_join(distinct_hash_join&&) = delete;
distinct_hash_join& operator=(distinct_hash_join const&) = delete;
distinct_hash_join& operator=(distinct_hash_join&&) = delete;

/**
* @brief Constructs a distinct hash join object for subsequent probe calls
*
* @param build The build table that contains distinct elements
* @param probe The probe table, from which the keys are probed
* @param has_nulls Flag to indicate if there exists any nulls in the `build` table or
* any `probe` table that will be used later for join
* @param compare_nulls Controls whether null join-key values should match or not
* @param stream CUDA stream used for device memory operations and kernel launches
*/
distinct_hash_join(cudf::table_view const& build,
cudf::table_view const& probe,
nullable_join has_nulls = nullable_join::YES,
null_equality compare_nulls = null_equality::EQUAL,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* Returns the row indices that can be used to construct the result of performing
* an inner join between two tables. @see cudf::inner_join().
*
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned indices' device memory.
*
* @return A pair of columns [`build_indices`, `probe_indices`] that can be used to construct
* the result of performing an inner join between two tables with `build` and `probe`
* as the join keys.
*/
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
std::unique_ptr<rmm::device_uvector<size_type>>>
inner_join(rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const;

private:
using impl_type = typename cudf::detail::distinct_hash_join<HasNested>; ///< Implementation type

std::unique_ptr<impl_type> _impl; ///< Distinct hash join implementation
};

/**
* @brief Returns a pair of row index vectors corresponding to all pairs
* of rows between the specified tables where the predicate evaluates to true.
Expand Down
Loading

0 comments on commit 71c9909

Please sign in to comment.