diff --git a/cpp/src/prims/detail/nbr_intersection.cuh b/cpp/src/prims/detail/nbr_intersection.cuh index 8261ec747f9..26b87f21dbb 100644 --- a/cpp/src/prims/detail/nbr_intersection.cuh +++ b/cpp/src/prims/detail/nbr_intersection.cuh @@ -319,7 +319,8 @@ struct pick_min_degree_t { } }; -template (nullptr); - auto intersection_size = set_intersection_by_key_with_mask( - indices0, - indices1, - edge_property_values0, - edge_property_values1, - mask_first, - nbr_intersection_indices.begin(), - nbr_intersection_e_property_values0, - nbr_intersection_e_property_values1, - local_edge_offset0, - local_degree0, - (std::is_same_v && edge_partition_e_mask), - local_edge_offset1, - local_degree1, - (std::is_same_v && edge_partition_e_mask), - nbr_intersection_offsets[i]); + edge_t intersection_size{}; + if (edge_partition_e_mask) { + intersection_size = + set_intersection_by_key_with_mask(indices0, + indices1, + edge_property_values0, + edge_property_values1, + (*edge_partition_e_mask).value_first(), + nbr_intersection_indices.begin(), + nbr_intersection_e_property_values0, + nbr_intersection_e_property_values1, + local_edge_offset0, + local_degree0, + std::is_same_v, + local_edge_offset1, + local_degree1, + std::is_same_v, + nbr_intersection_offsets[i]); + } else { + intersection_size = + set_intersection_by_key_with_mask(indices0, + indices1, + edge_property_values0, + edge_property_values1, + static_cast(nullptr), + nbr_intersection_indices.begin(), + nbr_intersection_e_property_values0, + nbr_intersection_e_property_values1, + local_edge_offset0, + local_degree0, + false, + local_edge_offset1, + local_degree1, + false, + nbr_intersection_offsets[i]); + } thrust::fill( thrust::seq, @@ -714,7 +737,7 @@ nbr_intersection(raft::handle_t const& handle, auto edge_mask_view = graph_view.edge_mask_view(); std::optional>> major_to_idx_map_ptr{ - std::nullopt}; + std::nullopt}; // idx to major_nbr_offsets std::optional> major_nbr_offsets{std::nullopt}; std::optional> major_nbr_indices{std::nullopt}; @@ -1041,7 +1064,7 @@ nbr_intersection(raft::handle_t const& handle, // 3. Collect neighbor list for minors (for the neighbors within the minor range for this GPU) std::optional>> minor_to_idx_map_ptr{ - std::nullopt}; + std::nullopt}; // idx to minor_nbr_offsets std::optional> minor_nbr_offsets{std::nullopt}; std::optional> minor_nbr_indices{std::nullopt}; diff --git a/cpp/src/prims/transform_e.cuh b/cpp/src/prims/transform_e.cuh index 93a2d040b60..5c83e0f1b71 100644 --- a/cpp/src/prims/transform_e.cuh +++ b/cpp/src/prims/transform_e.cuh @@ -42,7 +42,8 @@ namespace detail { int32_t constexpr transform_e_kernel_block_size = 512; -template edge_partition_e_mask, + EdgePartitionEdgeMaskWrapper edge_partition_e_mask, EdgePartitionEdgeValueOutputWrapper edge_partition_e_value_output, EdgeOp e_op) { @@ -72,35 +73,44 @@ __global__ void transform_e_packed_bool( auto num_edges = edge_partition.number_of_edges(); while (idx < static_cast(packed_bool_size(num_edges))) { - auto edge_mask = packed_bool_full_mask(); - if (edge_partition_e_mask) { edge_mask = *((*edge_partition_e_mask).value_first() + idx); } + [[maybe_unused]] auto edge_mask = + packed_bool_full_mask(); // relevant only when check_edge_mask is true + if constexpr (check_edge_mask) { edge_mask = *(edge_partition_e_mask.value_first() + idx); } auto local_edge_idx = idx * static_cast(packed_bools_per_word()) + static_cast(lane_id); int predicate{0}; - if ((local_edge_idx < num_edges) && (edge_mask & packed_bool_mask(lane_id))) { - auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(local_edge_idx); - auto major = edge_partition.major_from_major_idx_nocheck(major_idx); - auto major_offset = edge_partition.major_offset_from_major_nocheck(major); - auto minor = *(edge_partition.indices() + local_edge_idx); - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - - auto src = GraphViewType::is_storage_transposed ? minor : major; - auto dst = GraphViewType::is_storage_transposed ? major : minor; - auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; - auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; - predicate = e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(local_edge_idx)) - ? int{1} - : int{0}; + if (local_edge_idx < num_edges) { + bool compute_predicate = true; + if constexpr (check_edge_mask) { + compute_predicate = (edge_mask & packed_bool_mask(lane_id) != packed_bool_empty_mask()); + } + + if (compute_predicate) { + auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(local_edge_idx); + auto major = edge_partition.major_from_major_idx_nocheck(major_idx); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + auto minor = *(edge_partition.indices() + local_edge_idx); + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); + + auto src = GraphViewType::is_storage_transposed ? minor : major; + auto dst = GraphViewType::is_storage_transposed ? major : minor; + auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; + auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; + predicate = e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(local_edge_idx)) + ? int{1} + : int{0}; + } } + uint32_t new_val = __ballot_sync(uint32_t{0xffffffff}, predicate); if (lane_id == 0) { - if (edge_mask == packed_bool_full_mask()) { + if constexpr (check_edge_mask) { *(edge_partition_e_value_output.value_first() + idx) = new_val; } else { auto old_val = *(edge_partition_e_value_output.value_first() + idx); @@ -112,6 +122,99 @@ __global__ void transform_e_packed_bool( } } +template +struct update_e_value_t { + edge_partition_device_view_t + edge_partition{}; + EdgePartitionSrcValueInputWrapper edge_partition_src_value_input{}; + EdgePartitionDstValueInputWrapper edge_partition_dst_value_input{}; + EdgePartitionEdgeValueInputWrapper edge_partition_e_value_input{}; + EdgePartitionEdgeMaskWrapper edge_partition_e_mask{}; + EdgeOp e_op{}; + EdgeValueOutputWrapper edge_partition_e_value_output{}; + + __device__ void operator()(thrust::tuple edge) const + { + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + + auto major = thrust::get<0>(edge); + auto minor = thrust::get<1>(edge); + + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + auto major_idx = edge_partition.major_idx_from_major_nocheck(major); + assert(major_idx); + + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); + + vertex_t const* indices{nullptr}; + edge_t edge_offset{}; + edge_t local_degree{}; + thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(*major_idx); + auto lower_it = thrust::lower_bound(thrust::seq, indices, indices + local_degree, minor); + auto upper_it = thrust::upper_bound(thrust::seq, lower_it, indices + local_degree, minor); + + auto src = GraphViewType::is_storage_transposed ? minor : major; + auto dst = GraphViewType::is_storage_transposed ? major : minor; + auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; + auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; + + for (auto it = lower_it; it != upper_it; ++it) { + assert(*it == minor); + if constexpr (check_edge_mask) { + if (edge_partition_e_mask.get(edge_offset + thrust::distance(indices, it))) { + auto e_op_result = + e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(edge_offset + thrust::distance(indices, it))); + edge_partition_e_value_output.set(edge_offset + thrust::distance(indices, it), + e_op_result); + } + } else { + auto e_op_result = + e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(edge_offset + thrust::distance(indices, it))); + edge_partition_e_value_output.set(edge_offset + thrust::distance(indices, it), e_op_result); + } + } + } + + __device__ void operator()(typename GraphViewType::edge_type i) const + { + auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(i); + auto major = edge_partition.major_from_major_idx_nocheck(major_idx); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + auto minor = *(edge_partition.indices() + i); + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); + + auto src = GraphViewType::is_storage_transposed ? minor : major; + auto dst = GraphViewType::is_storage_transposed ? major : minor; + auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; + auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; + auto e_op_result = e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(i)); + edge_partition_e_value_output.set(i, e_op_result); + } +}; + } // namespace detail /** @@ -228,47 +331,68 @@ void transform_e(raft::handle_t const& handle, raft::grid_1d_thread_t update_grid(num_edges, detail::transform_e_kernel_block_size, handle.get_device_properties().maxGridSize[0]); - detail::transform_e_packed_bool - <<>>( - edge_partition, - edge_partition_src_value_input, - edge_partition_dst_value_input, - edge_partition_e_value_input, - edge_partition_e_mask, - edge_partition_e_value_output, - e_op); + if (edge_partition_e_mask) { + detail::transform_e_packed_bool + <<>>( + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + *edge_partition_e_mask, + edge_partition_e_value_output, + e_op); + } else { + detail::transform_e_packed_bool + <<>>( + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + std::byte{}, // dummy + edge_partition_e_value_output, + e_op); + } } } else { - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(edge_t{0}), - thrust::make_counting_iterator(num_edges), - [e_op, - edge_partition, - edge_partition_src_value_input, - edge_partition_dst_value_input, - edge_partition_e_value_input, - edge_partition_e_mask, - edge_partition_e_value_output] __device__(edge_t i) { - if (!edge_partition_e_mask || (*edge_partition_e_mask).get(i)) { - auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(i); - auto major = edge_partition.major_from_major_idx_nocheck(major_idx); - auto major_offset = edge_partition.major_offset_from_major_nocheck(major); - auto minor = *(edge_partition.indices() + i); - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - - auto src = GraphViewType::is_storage_transposed ? minor : major; - auto dst = GraphViewType::is_storage_transposed ? major : minor; - auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; - auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; - auto e_op_result = e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(i)); - edge_partition_e_value_output.set(i, e_op_result); - } - }); + if (edge_partition_e_mask) { + thrust::for_each(handle.get_thrust_policy(), + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(num_edges), + detail::update_e_value_t{ + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + *edge_partition_e_mask, + e_op, + edge_partition_e_value_output}); + } else { + thrust::for_each(handle.get_thrust_policy(), + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(num_edges), + detail::update_e_value_t{ + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + std::byte{}, // dummy + e_op, + edge_partition_e_value_output}); + } } } } @@ -467,53 +591,45 @@ void transform_e(raft::handle_t const& handle, auto edge_partition_e_value_output = edge_partition_e_output_device_view_t(edge_value_output, i); - thrust::for_each( - handle.get_thrust_policy(), - edge_first + edge_partition_offsets[i], - edge_first + edge_partition_offsets[i + 1], - [e_op, - edge_partition, - edge_partition_src_value_input, - edge_partition_dst_value_input, - edge_partition_e_value_input, - edge_partition_e_mask, - edge_partition_e_value_output] __device__(thrust::tuple edge) { - auto major = thrust::get<0>(edge); - auto minor = thrust::get<1>(edge); - - auto major_offset = edge_partition.major_offset_from_major_nocheck(major); - auto major_idx = edge_partition.major_idx_from_major_nocheck(major); - assert(major_idx); - - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - - vertex_t const* indices{nullptr}; - edge_t edge_offset{}; - edge_t local_degree{}; - thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(*major_idx); - auto lower_it = thrust::lower_bound(thrust::seq, indices, indices + local_degree, minor); - auto upper_it = thrust::upper_bound(thrust::seq, lower_it, indices + local_degree, minor); - - auto src = GraphViewType::is_storage_transposed ? minor : major; - auto dst = GraphViewType::is_storage_transposed ? major : minor; - auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; - auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; - - for (auto it = lower_it; it != upper_it; ++it) { - assert(*it == minor); - if (!edge_partition_e_mask || - ((*edge_partition_e_mask).get(edge_offset + thrust::distance(indices, it)))) { - auto e_op_result = - e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + thrust::distance(indices, it))); - edge_partition_e_value_output.set(edge_offset + thrust::distance(indices, it), - e_op_result); - } - } - }); + if (edge_partition_e_mask) { + thrust::for_each(handle.get_thrust_policy(), + edge_first + edge_partition_offsets[i], + edge_first + edge_partition_offsets[i + 1], + detail::update_e_value_t{ + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + *edge_partition_e_mask, + e_op, + edge_partition_e_value_output}); + } else { + thrust::for_each(handle.get_thrust_policy(), + edge_first + edge_partition_offsets[i], + edge_first + edge_partition_offsets[i + 1], + detail::update_e_value_t{ + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + std::byte{}, // dummy + e_op, + edge_partition_e_value_output}); + } } } diff --git a/cpp/src/prims/transform_reduce_e.cuh b/cpp/src/prims/transform_reduce_e.cuh index 483ab64dcd9..7acc7461268 100644 --- a/cpp/src/prims/transform_reduce_e.cuh +++ b/cpp/src/prims/transform_reduce_e.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include @@ -89,48 +90,51 @@ __global__ void transform_reduce_e_hypersparse( while (idx < static_cast(dcs_nzd_vertex_count)) { auto major = *(edge_partition.major_from_major_hypersparse_idx_nocheck(static_cast(idx))); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); auto major_idx = major_start_offset + idx; // major_offset != major_idx in the hypersparse region vertex_t const* indices{nullptr}; edge_t edge_offset{}; edge_t local_degree{}; thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_idx); - auto sum = thrust::transform_reduce( - thrust::seq, - thrust::make_counting_iterator(edge_t{0}), - thrust::make_counting_iterator(local_degree), - [&edge_partition, - &edge_partition_src_value_input, - &edge_partition_dst_value_input, - &edge_partition_e_value_input, - &edge_partition_e_mask, - &e_op, - major, - indices, - edge_offset] __device__(auto i) { - if (!edge_partition_e_mask || (*edge_partition_e_mask).get(edge_offset + i)) { - auto major_offset = edge_partition.major_offset_from_major_nocheck(major); - auto minor = indices[i]; - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - auto src = GraphViewType::is_storage_transposed ? minor : major; - auto dst = GraphViewType::is_storage_transposed ? major : minor; - auto src_offset = GraphViewType::is_storage_transposed - ? minor_offset - : static_cast(major_offset); - auto dst_offset = GraphViewType::is_storage_transposed - ? static_cast(major_offset) - : minor_offset; - return e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + i)); - } else { - return e_op_result_t{}; - } - }, - e_op_result_t{}, - edge_property_add); + + auto call_e_op = call_e_op_t{edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + e_op, + major, + major_offset, + indices, + edge_offset}; + + e_op_result_t sum{}; + if (edge_partition_e_mask) { + sum = thrust::transform_reduce( + thrust::seq, + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(local_degree), + [&edge_partition_e_mask, &call_e_op, edge_offset] __device__(auto i) { + if ((*edge_partition_e_mask).get(edge_offset + i)) { + return call_e_op(i); + } else { + return e_op_result_t{}; + } + }, + e_op_result_t{}, + edge_property_add); + } else { + sum = thrust::transform_reduce(thrust::seq, + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(local_degree), + call_e_op, + e_op_result_t{}, + edge_property_add); + } e_op_result_sum = edge_property_add(e_op_result_sum, sum); idx += gridDim.x * blockDim.x; @@ -175,50 +179,50 @@ __global__ void transform_reduce_e_low_degree( property_op edge_property_add{}; e_op_result_t e_op_result_sum{}; while (idx < static_cast(major_range_last - major_range_first)) { - auto major_offset = major_start_offset + idx; + auto major_offset = static_cast(major_start_offset + idx); + auto major = edge_partition.major_from_major_offset_nocheck(major_offset); vertex_t const* indices{nullptr}; edge_t edge_offset{}; edge_t local_degree{}; thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset); - auto sum = thrust::transform_reduce( - thrust::seq, - thrust::make_counting_iterator(edge_t{0}), - thrust::make_counting_iterator(local_degree), - [&edge_partition, - &edge_partition_src_value_input, - &edge_partition_dst_value_input, - &edge_partition_e_value_input, - &edge_partition_e_mask, - &e_op, - major_offset, - indices, - edge_offset] __device__(auto i) { - if (!edge_partition_e_mask || (*edge_partition_e_mask).get(edge_offset + i)) { - auto minor = indices[i]; - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - auto src = GraphViewType::is_storage_transposed - ? minor - : edge_partition.major_from_major_offset_nocheck(major_offset); - auto dst = GraphViewType::is_storage_transposed - ? edge_partition.major_from_major_offset_nocheck(major_offset) - : minor; - auto src_offset = GraphViewType::is_storage_transposed - ? minor_offset - : static_cast(major_offset); - auto dst_offset = GraphViewType::is_storage_transposed - ? static_cast(major_offset) - : minor_offset; - return e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + i)); - } else { - return e_op_result_t{}; - } - }, - e_op_result_t{}, - edge_property_add); + + auto call_e_op = call_e_op_t{edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + e_op, + major, + major_offset, + indices, + edge_offset}; + + e_op_result_t sum{}; + if (edge_partition_e_mask) { + sum = thrust::transform_reduce( + thrust::seq, + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(local_degree), + [&edge_partition_e_mask, &call_e_op, edge_offset] __device__(auto i) { + if ((*edge_partition_e_mask).get(edge_offset + i)) { + return call_e_op(i); + } else { + return e_op_result_t{}; + } + }, + e_op_result_t{}, + edge_property_add); + } else { + sum = thrust::transform_reduce(thrust::seq, + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(local_degree), + call_e_op, + e_op_result_t{}, + edge_property_add); + } e_op_result_sum = edge_property_add(e_op_result_sum, sum); idx += gridDim.x * blockDim.x; @@ -264,30 +268,37 @@ __global__ void transform_reduce_e_mid_degree( property_op edge_property_add{}; e_op_result_t e_op_result_sum{}; while (idx < static_cast(major_range_last - major_range_first)) { - auto major_offset = major_start_offset + idx; + auto major_offset = static_cast(major_start_offset + idx); + auto major = edge_partition.major_from_major_offset_nocheck(major_offset); vertex_t const* indices{nullptr}; edge_t edge_offset{}; edge_t local_degree{}; thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset); - for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) { - if (!edge_partition_e_mask || (*edge_partition_e_mask).get(edge_offset + i)) { - auto minor = indices[i]; - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - auto src = GraphViewType::is_storage_transposed - ? minor - : edge_partition.major_from_major_offset_nocheck(major_offset); - auto dst = GraphViewType::is_storage_transposed - ? edge_partition.major_from_major_offset_nocheck(major_offset) - : minor; - auto src_offset = - GraphViewType::is_storage_transposed ? minor_offset : static_cast(major_offset); - auto dst_offset = - GraphViewType::is_storage_transposed ? static_cast(major_offset) : minor_offset; - auto e_op_result = e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + i)); + + auto call_e_op = call_e_op_t{edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + e_op, + major, + major_offset, + indices, + edge_offset}; + + if (edge_partition_e_mask) { + for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) { + if ((*edge_partition_e_mask).get(edge_offset + i)) { + auto e_op_result = call_e_op(i); + e_op_result_sum = edge_property_add(e_op_result_sum, e_op_result); + } + } + } else { + for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) { + auto e_op_result = call_e_op(i); e_op_result_sum = edge_property_add(e_op_result_sum, e_op_result); } } @@ -331,30 +342,37 @@ __global__ void transform_reduce_e_high_degree( property_op edge_property_add{}; e_op_result_t e_op_result_sum{}; while (idx < static_cast(major_range_last - major_range_first)) { - auto major_offset = major_start_offset + idx; + auto major_offset = static_cast(major_start_offset + idx); + auto major = edge_partition.major_from_major_offset_nocheck(major_offset); vertex_t const* indices{nullptr}; edge_t edge_offset{}; edge_t local_degree{}; thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset); - for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) { - if (!edge_partition_e_mask || (*edge_partition_e_mask).get(edge_offset + i)) { - auto minor = indices[i]; - auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); - auto src = GraphViewType::is_storage_transposed - ? minor - : edge_partition.major_from_major_offset_nocheck(major_offset); - auto dst = GraphViewType::is_storage_transposed - ? edge_partition.major_from_major_offset_nocheck(major_offset) - : minor; - auto src_offset = - GraphViewType::is_storage_transposed ? minor_offset : static_cast(major_offset); - auto dst_offset = - GraphViewType::is_storage_transposed ? static_cast(major_offset) : minor_offset; - auto e_op_result = e_op(src, - dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + i)); + + auto call_e_op = call_e_op_t{edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + e_op, + major, + major_offset, + indices, + edge_offset}; + + if (edge_partition_e_mask) { + for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) { + if ((*edge_partition_e_mask).get(edge_offset + i)) { + auto e_op_result = call_e_op(i); + e_op_result_sum = edge_property_add(e_op_result_sum, e_op_result); + } + } + } else { + for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) { + auto e_op_result = call_e_op(i); e_op_result_sum = edge_property_add(e_op_result_sum, e_op_result); } }