Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix access violation in resume_after and resume_on_signal #1342

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,23 @@ namespace winrt::impl
template <typename T>
void await_suspend(impl::coroutine_handle<T> handle)
{
set_cancellable_promise_from_handle(handle);
handle_type<timer_traits> new_timer;
new_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr)));

m_handle = handle;
create_threadpool_timer();
state expected = state::idle;
if (m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
set_cancellable_promise_from_handle(handle);
sylveon marked this conversation as resolved.
Show resolved Hide resolved

m_handle = handle;
m_timer = std::move(new_timer);

set_threadpool_timer();
}
else
{
throw hresult_illegal_method_call();
}
}

void await_resume()
Expand All @@ -396,17 +409,10 @@ namespace winrt::impl
}

private:
void create_threadpool_timer()
void set_threadpool_timer()
{
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr)));
int64_t relative_count = -m_duration.count();
WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

static int32_t __stdcall fallback_SetThreadpoolTimerEx(winrt::impl::ptp_timer, void*, uint32_t, uint32_t) noexcept
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -495,12 +501,25 @@ namespace winrt::impl
}

template <typename T>
void await_suspend(impl::coroutine_handle<T> resume)
void await_suspend(impl::coroutine_handle<T> handle)
{
set_cancellable_promise_from_handle(resume);
handle_type<wait_traits> new_wait;
new_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr)));

m_resume = resume;
create_threadpool_wait();
state expected = state::idle;
if (m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
set_cancellable_promise_from_handle(handle);

m_resume = handle;
m_wait = std::move(new_wait);

set_threadpool_wait();
}
else
{
throw hresult_illegal_method_call();
}
}

bool await_resume()
Expand All @@ -518,18 +537,11 @@ namespace winrt::impl
return 0; // pretend wait has already triggered and a callback is on its way
}

void create_threadpool_wait()
void set_threadpool_wait()
{
m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr)));
int64_t relative_count = -m_timeout.count();
int64_t* file_time = relative_count != 0 ? &relative_count : nullptr;
WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

void fire_immediately() noexcept
Expand Down
42 changes: 42 additions & 0 deletions test/old_tests/UnitTests/async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,26 @@ TEST_CASE("async, resume_after")
REQUIRE(after != GetCurrentThreadId());
}

namespace
{
IAsyncAction test_resume_after_illegal_state(winrt::impl::timespan_awaiter &awaiter)
{
co_await awaiter;
}
}

TEST_CASE("async, resume_after, illegal_state")
{
auto awaiter = resume_after(1s);

IAsyncAction first = test_resume_after_illegal_state(awaiter);
IAsyncAction second = test_resume_after_illegal_state(awaiter);

REQUIRE_THROWS_AS(second.get(), hresult_illegal_method_call);

first.get(); // allow first coroutine to succeed
}

//
// Other tests already excercise resume_on_signal so here we focus on testing the timeout.
//
Expand Down Expand Up @@ -1584,3 +1604,25 @@ TEST_CASE("async, resume_on_signal")
SetEvent(event.get()); // allow final resume_on_signal to succeed
async.get();
}

namespace
{
IAsyncAction test_resume_on_signal_illegal_state(winrt::impl::signal_awaiter &awaiter)
{
co_await awaiter;
}
}

TEST_CASE("async, resume_on_signal, illegal_state")
{
handle event { CreateEvent(nullptr, false, false, nullptr) };
auto awaiter = resume_on_signal(event.get());

IAsyncAction first = test_resume_on_signal_illegal_state(awaiter);
IAsyncAction second = test_resume_on_signal_illegal_state(awaiter);

REQUIRE_THROWS_AS(second.get(), hresult_illegal_method_call);

SetEvent(event.get()); // allow first coroutine to succeed
first.get();
}