From ec66dbcb6c72df1261b059ec9bb6f4b962a25a61 Mon Sep 17 00:00:00 2001 From: Pat Hickey Date: Mon, 13 Jan 2025 16:53:38 -0800 Subject: [PATCH] code motion: AbortOnDropJoinHandle lives in a mod just for now at some point soon this will get some other abstraction so we aren't tied directly to tokio for tasks --- crates/wasi/src/runtime/task.rs | 53 +++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 crates/wasi/src/runtime/task.rs diff --git a/crates/wasi/src/runtime/task.rs b/crates/wasi/src/runtime/task.rs new file mode 100644 index 000000000000..389998b2f6ce --- /dev/null +++ b/crates/wasi/src/runtime/task.rs @@ -0,0 +1,53 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +/// Exactly like a [`tokio::task::JoinHandle`], except that it aborts the task when +/// the handle is dropped. +/// +/// This behavior makes it easier to tie a worker task to the lifetime of a Resource +/// by keeping this handle owned by the Resource. +#[derive(Debug)] +pub struct AbortOnDropJoinHandle(tokio::task::JoinHandle); +impl AbortOnDropJoinHandle { + /// Abort the task and wait for it to finish. Optionally returns the result + /// of the task if it ran to completion prior to being aborted. + pub(crate) async fn cancel(mut self) -> Option { + self.0.abort(); + + match (&mut self.0).await { + Ok(value) => Some(value), + Err(err) if err.is_cancelled() => None, + Err(err) => std::panic::resume_unwind(err.into_panic()), + } + } +} +impl Drop for AbortOnDropJoinHandle { + fn drop(&mut self) { + self.0.abort() + } +} +impl std::ops::Deref for AbortOnDropJoinHandle { + type Target = tokio::task::JoinHandle; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl std::ops::DerefMut for AbortOnDropJoinHandle { + fn deref_mut(&mut self) -> &mut tokio::task::JoinHandle { + &mut self.0 + } +} +impl From> for AbortOnDropJoinHandle { + fn from(jh: tokio::task::JoinHandle) -> Self { + AbortOnDropJoinHandle(jh) + } +} +impl Future for AbortOnDropJoinHandle { + type Output = T; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.as_mut().0).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.expect("child task panicked")), + } + } +}