Skip to content

Commit

Permalink
trying wrapping iterators alone
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jan 12, 2025
1 parent cc7c1bb commit c6430d1
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions cub/cub/device/device_segmented_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#pragma once

#include <cub/config.cuh>
#include "thrust/iterator/constant_iterator.h"

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
Expand All @@ -47,6 +48,38 @@

CUB_NAMESPACE_BEGIN


template <typename Iterator, typename OffsetItT>
class OffsetIteratorT : public THRUST_NS_QUALIFIER::iterator_adaptor<OffsetIteratorT<Iterator, OffsetItT>, Iterator>
{
public:
using super_t = THRUST_NS_QUALIFIER::iterator_adaptor<OffsetIteratorT<Iterator, OffsetItT>, Iterator>;

OffsetIteratorT() = default;

_CCCL_HOST_DEVICE OffsetIteratorT(const Iterator& it, OffsetItT offset_it)
: super_t(it)
, offset_it(offset_it)
{}

// befriend thrust::iterator_core_access to allow it access to the private interface below
friend class THRUST_NS_QUALIFIER::iterator_core_access;

private:
OffsetItT offset_it;

_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
return *(this->base() + (*offset_it));
}
};

template <typename Iterator, typename OffsetItT>
_CCCL_HOST_DEVICE OffsetIteratorT<Iterator, OffsetItT> make_offset_iterator(const Iterator& it, OffsetItT offset_it)
{
return OffsetIteratorT<Iterator, OffsetItT>{it, offset_it};
}

//! @rst
//! DeviceSegmentedSort provides device-wide, parallel operations for
//! computing a batched sort across multiple, non-overlapping sequences of
Expand Down Expand Up @@ -146,10 +179,12 @@ private:
EndOffsetIteratorT d_end_offsets,
cudaStream_t stream = 0)
{
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;

constexpr bool is_descending = false;
constexpr bool is_overwrite_okay = false;
using DispatchT =
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<KeyT> d_keys(const_cast<KeyT*>(d_keys_in), d_keys_out);
DoubleBuffer<NullType> d_values;
Expand All @@ -162,7 +197,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -321,8 +356,9 @@ private:
{
constexpr bool is_descending = true;
constexpr bool is_overwrite_okay = false;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;
using DispatchT =
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<KeyT> d_keys(const_cast<KeyT*>(d_keys_in), d_keys_out);
DoubleBuffer<NullType> d_values;
Expand All @@ -335,7 +371,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -488,9 +524,10 @@ private:
{
constexpr bool is_descending = false;
constexpr bool is_overwrite_okay = true;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;

using DispatchT =
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<NullType> d_values;

Expand All @@ -502,7 +539,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -658,9 +695,10 @@ private:
{
constexpr bool is_descending = true;
constexpr bool is_overwrite_okay = true;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;

using DispatchT =
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
DispatchSegmentedSort<is_descending, KeyT, cub::NullType, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<NullType> d_values;

Expand All @@ -672,7 +710,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -1379,7 +1417,8 @@ private:
{
constexpr bool is_descending = false;
constexpr bool is_overwrite_okay = false;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<KeyT> d_keys(const_cast<KeyT*>(d_keys_in), d_keys_out);
DoubleBuffer<ValueT> d_values(const_cast<ValueT*>(d_values_in), d_values_out);
Expand All @@ -1392,7 +1431,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -1578,7 +1617,8 @@ private:
{
constexpr bool is_descending = true;
constexpr bool is_overwrite_okay = false;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, offset_it_t>;

DoubleBuffer<KeyT> d_keys(const_cast<KeyT*>(d_keys_in), d_keys_out);
DoubleBuffer<ValueT> d_values(const_cast<ValueT*>(d_values_in), d_values_out);
Expand All @@ -1591,7 +1631,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -1771,7 +1811,8 @@ private:
{
constexpr bool is_descending = false;
constexpr bool is_overwrite_okay = true;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, offset_it_t>;

return DispatchT::Dispatch(
d_temp_storage,
Expand All @@ -1781,7 +1822,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down Expand Up @@ -1966,7 +2007,8 @@ private:
{
constexpr bool is_descending = true;
constexpr bool is_overwrite_okay = true;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, EndOffsetIteratorT>;
using offset_it_t = OffsetIteratorT<EndOffsetIteratorT, thrust::constant_iterator<int>>;
using DispatchT = DispatchSegmentedSort<is_descending, KeyT, ValueT, int, BeginOffsetIteratorT, offset_it_t>;

return DispatchT::Dispatch(
d_temp_storage,
Expand All @@ -1976,7 +2018,7 @@ private:
num_items,
num_segments,
d_begin_offsets,
d_end_offsets,
{d_end_offsets, thrust::make_constant_iterator(0)},
is_overwrite_okay,
stream);
}
Expand Down

0 comments on commit c6430d1

Please sign in to comment.