Skip to content

Commit

Permalink
Merge pull request #1327 from NVIDIA/schedule-mem-fn-customization
Browse files Browse the repository at this point in the history
migrate `schedule` customizations from `tag_invoke` to member functions
  • Loading branch information
ericniebler authored May 14, 2024
2 parents 0a2666f + d5d1aac commit 4e573c3
Show file tree
Hide file tree
Showing 22 changed files with 147 additions and 194 deletions.
18 changes: 8 additions & 10 deletions examples/benchmark/static_thread_pool_old.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ namespace exec_old {
using __id = scheduler;
bool operator==(const scheduler&) const = default;

stdexec::forward_progress_guarantee
query(stdexec::get_forward_progress_guarantee_t) const noexcept {
auto query(stdexec::get_forward_progress_guarantee_t) const noexcept
-> stdexec::forward_progress_guarantee {
return stdexec::forward_progress_guarantee::parallel;
}

Expand All @@ -168,6 +168,7 @@ namespace exec_old {
using sender_concept = stdexec::sender_t;
using completion_signatures =
stdexec::completion_signatures<stdexec::set_value_t(), stdexec::set_stopped_t()>;

private:
template <typename Receiver>
auto make_operation_(Receiver r) const -> operation<stdexec::__id<Receiver>> {
Expand Down Expand Up @@ -203,21 +204,18 @@ namespace exec_old {
static_thread_pool& pool_;
};

sender make_sender_() const {
return sender{*pool_};
}

friend sender tag_invoke(stdexec::schedule_t, const scheduler& s) noexcept {
return s.make_sender_();
}

friend class static_thread_pool;

explicit scheduler(static_thread_pool& pool) noexcept
: pool_(&pool) {
}

static_thread_pool* pool_;

public:
sender schedule() const noexcept {
return sender{*pool_};
}
};

scheduler get_scheduler() noexcept {
Expand Down
45 changes: 21 additions & 24 deletions include/exec/any_sender_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,11 @@ namespace exec {

using __sender_t = _ScheduleSender;

auto schedule() const noexcept -> __sender_t {
STDEXEC_ASSERT(__storage_.__get_vtable()->__schedule_);
return __storage_.__get_vtable()->__schedule_(__storage_.__get_object_pointer());
}

template <class _Tag, class... _As>
requires __callable<const __query_vtable<_SchedulerQueries, false>&, _Tag, void*, _As...>
auto query(_Tag, _As&&... __as) const //
Expand Down Expand Up @@ -1085,7 +1090,7 @@ namespace exec {
__mtype<__query_vtable<_SchedulerQueries, false>>{}, __mtype<_Scheduler>{})},
[](void* __object_pointer) noexcept -> __sender_t {
const _Scheduler& __scheduler = *static_cast<const _Scheduler*>(__object_pointer);
return __sender_t{schedule(__scheduler)};
return __sender_t{stdexec::schedule(__scheduler)};
},
[](const void* __self, const void* __other) noexcept -> bool {
static_assert(
Expand All @@ -1099,13 +1104,6 @@ namespace exec {
}
};

template <same_as<__scheduler> _Self>
STDEXEC_MEMFN_DECL(auto schedule)(this const _Self& __self) noexcept -> __sender_t {
STDEXEC_ASSERT(__self.__storage_.__get_vtable()->__schedule_);
return __self.__storage_.__get_vtable()->__schedule_(
__self.__storage_.__get_object_pointer());
}

friend auto
operator==(const __scheduler& __self, const __scheduler& __other) noexcept -> bool {
if (__self.__storage_.__get_vtable() != __other.__storage_.__get_vtable()) {
Expand Down Expand Up @@ -1245,31 +1243,30 @@ namespace exec {
__any::__scheduler<__schedule_sender, queries<_SchedulerQueries...>>;

__scheduler_base __scheduler_;

public:
using __t = any_scheduler;
using __id = any_scheduler;

template <class _Scheduler>
requires(
!stdexec::__decays_to<_Scheduler, any_scheduler> && stdexec::scheduler<_Scheduler>)
any_scheduler(_Scheduler&& __scheduler)
template <stdexec::__none_of<any_scheduler> _Scheduler>
requires stdexec::scheduler<_Scheduler>
any_scheduler(_Scheduler __scheduler)
: __scheduler_{static_cast<_Scheduler&&>(__scheduler)} {
}

private:
template <class _Tag, stdexec::__decays_to<any_scheduler> Self, class... _As>
requires stdexec::
tag_invocable<_Tag, stdexec::__copy_cvref_t<Self, __scheduler_base>, _As...>
friend auto tag_invoke(_Tag, Self&& __self, _As&&... __as) //
noexcept(
std::
is_nothrow_invocable_v<_Tag, stdexec::__copy_cvref_t<Self, __scheduler_base>, _As...>) {
return stdexec::tag_invoke(
_Tag{}, static_cast<Self&&>(__self).__scheduler_, static_cast<_As&&>(__as)...);
auto schedule() const noexcept -> __schedule_sender {
return __scheduler_.schedule();
}

template <class _Tag, class... _As>
requires stdexec::tag_invocable<_Tag, const __scheduler_base&, _As...>
auto query(_Tag, _As&&... __as) const
noexcept(stdexec::nothrow_tag_invocable<_Tag, const __scheduler_base&, _As...>)
-> stdexec::tag_invoke_result_t<_Tag, const __scheduler_base&, _As...> {
return stdexec::tag_invoke(_Tag(), __scheduler_, static_cast<_As&&>(__as)...);
}

friend auto operator==(const any_scheduler& __self, const any_scheduler& __other) noexcept
-> bool = default;
auto operator==(const any_scheduler&) const noexcept -> bool = default;
};
};
};
Expand Down
13 changes: 4 additions & 9 deletions include/exec/libdispatch_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,21 +194,16 @@ namespace exec {
libdispatch_queue *queue;
};

sender make_sender() const {
auto schedule() const noexcept -> sender {
return sender{queue_};
}

STDEXEC_MEMFN_FRIEND(schedule);
STDEXEC_MEMFN_DECL(sender schedule)(this libdispatch_scheduler const &s) noexcept {
return s.make_sender();
}

domain query(stdexec::get_domain_t) const noexcept {
auto query(stdexec::get_domain_t) const noexcept -> domain {
return {};
}

stdexec::forward_progress_guarantee
query(stdexec::get_forward_progress_guarantee_t) const noexcept {
auto query(stdexec::get_forward_progress_guarantee_t) const noexcept
-> stdexec::forward_progress_guarantee {
return stdexec::forward_progress_guarantee::parallel;
}

Expand Down
7 changes: 2 additions & 5 deletions include/exec/linux/io_uring_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1146,11 +1146,8 @@ namespace exec {
}
};

private:
STDEXEC_MEMFN_FRIEND(schedule);

STDEXEC_MEMFN_DECL(auto schedule)(this const __scheduler& __sched) -> __schedule_sender {
return __schedule_sender{__schedule_env{__sched.__context_}};
auto schedule() const -> __schedule_sender {
return __schedule_sender{__schedule_env{__context_}};
}

friend auto tag_invoke(exec::now_t, const __scheduler&) noexcept
Expand Down
2 changes: 1 addition & 1 deletion include/exec/reschedule.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace exec {
}
};

STDEXEC_MEMFN_DECL(auto schedule)(this __scheduler) noexcept -> __sender {
auto schedule() const noexcept -> __sender {
return {};
}

Expand Down
51 changes: 24 additions & 27 deletions include/exec/static_thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,23 +316,11 @@ namespace exec {
~static_thread_pool_();

struct scheduler {
using __t = scheduler;
using __id = scheduler;
auto operator==(const scheduler&) const -> bool = default;

auto query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee {
return forward_progress_guarantee::parallel;
}

auto query(get_domain_t) const noexcept -> domain {
return {};
}

private:
template <typename ReceiverId>
friend struct operation;

class sender {
class _sender {
struct env {
static_thread_pool_& pool_;
remote_queue* queue_;
Expand All @@ -345,8 +333,8 @@ namespace exec {
};

public:
using __t = sender;
using __id = sender;
using __t = _sender;
using __id = _sender;
using sender_concept = sender_t;
using completion_signatures =
stdexec::completion_signatures<set_value_t(), set_stopped_t()>;
Expand All @@ -365,13 +353,13 @@ namespace exec {
}

template <receiver Receiver>
STDEXEC_MEMFN_DECL(auto connect)(this sender sndr, Receiver rcvr) -> operation_t<Receiver> {
STDEXEC_MEMFN_DECL(auto connect)(this _sender sndr, Receiver rcvr) -> operation_t<Receiver> {
return sndr.make_operation_(static_cast<Receiver&&>(rcvr));
}

friend struct static_thread_pool_::scheduler;

explicit sender(
explicit _sender(
static_thread_pool_& pool,
remote_queue* queue,
std::size_t threadIndex,
Expand All @@ -388,15 +376,6 @@ namespace exec {
nodemask constraints_{};
};

[[nodiscard]]
auto make_sender_() const -> sender {
return sender{*pool_, queue_, thread_idx_, nodemask_};
}

STDEXEC_MEMFN_DECL(auto schedule)(this const scheduler& sch) noexcept -> sender {
return sch.make_sender_();
}

friend class static_thread_pool_;

explicit scheduler(
Expand Down Expand Up @@ -429,6 +408,24 @@ namespace exec {
remote_queue* queue_;
nodemask nodemask_;
std::size_t thread_idx_{std::numeric_limits<std::size_t>::max()};

public:
using __t = scheduler;
using __id = scheduler;
auto operator==(const scheduler&) const -> bool = default;

[[nodiscard]]
auto schedule() const noexcept -> _sender {
return _sender{*pool_, queue_, thread_idx_, nodemask_};
}

auto query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee {
return forward_progress_guarantee::parallel;
}

auto query(get_domain_t) const noexcept -> domain {
return {};
}
};

auto get_scheduler() noexcept -> scheduler {
Expand Down Expand Up @@ -1004,7 +1001,7 @@ namespace exec {
template <typename ReceiverId>
class static_thread_pool_::operation<ReceiverId>::__t : public task_base {
using __id = operation;
friend static_thread_pool_::scheduler::sender;
friend static_thread_pool_::scheduler::_sender;

static_thread_pool_& pool_;
remote_queue* queue_;
Expand Down
12 changes: 4 additions & 8 deletions include/exec/timed_thread_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,17 +355,13 @@ namespace exec {
return schedule_at{*self.context_, tp};
}

private:
STDEXEC_MEMFN_FRIEND(schedule);

STDEXEC_MEMFN_DECL(auto schedule)(this const timed_thread_scheduler& self) noexcept -> schedule_at {
return exec::schedule_at(self, time_point());
auto schedule() const noexcept -> schedule_at {
return exec::schedule_at(*this, time_point());
}

friend auto
operator==(const timed_thread_scheduler& sched1, const timed_thread_scheduler& sched2) noexcept
-> bool = default;
auto operator==(const timed_thread_scheduler&) const noexcept -> bool = default;

private:
timed_thread_context* context_;
};

Expand Down
6 changes: 3 additions & 3 deletions include/exec/trampoline_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ namespace exec {
std::size_t __max_recursion_depth_;
};

STDEXEC_MEMFN_DECL(auto schedule)(this __scheduler __self) noexcept -> __schedule_sender {
return __schedule_sender{__self.__max_recursion_depth_};
public:
auto schedule() const noexcept -> __schedule_sender {
return __schedule_sender{__max_recursion_depth_};
}

public:
auto operator==(const __scheduler&) const noexcept -> bool = default;
};

Expand Down
4 changes: 2 additions & 2 deletions include/nvexec/multi_gpu_context.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ namespace nvexec {
return ensure_started_th<S>(static_cast<S&&>(sndr), sch.context_state_);
}

STDEXEC_MEMFN_DECL(sender_t schedule)(this const multi_gpu_stream_scheduler& self) noexcept {
return {self.num_devices_, self.context_state_};
sender_t schedule() const noexcept {
return {num_devices_, context_state_};
}

template <sender S>
Expand Down
5 changes: 3 additions & 2 deletions include/nvexec/stream_context.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,9 @@ namespace nvexec {

STDEXEC_ATTRIBUTE((host, device))

inline STDEXEC_MEMFN_DECL(sender_t schedule)(this const stream_scheduler& self) noexcept {
return {self.context_state_};
inline auto
schedule() const noexcept -> sender_t {
return {context_state_};
}

STDEXEC_MEMFN_DECL(std::true_type __has_algorithm_customizations)(this const stream_scheduler& self) noexcept {
Expand Down
2 changes: 1 addition & 1 deletion include/stdexec/__detail/__let.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace stdexec {
}
};

STDEXEC_MEMFN_DECL(auto schedule)(this __unknown_scheduler) noexcept {
auto schedule() const noexcept {
return __sender();
}

Expand Down
Loading

0 comments on commit 4e573c3

Please sign in to comment.