From ec66dbcb6c72df1261b059ec9bb6f4b962a25a61 Mon Sep 17 00:00:00 2001
From: Pat Hickey
Date: Mon, 13 Jan 2025 16:53:38 -0800
Subject: [PATCH] code motion: AbortOnDropJoinHandle lives in a mod just for
now
at some point soon this will get some other abstraction so we aren't
tied directly to tokio for tasks
---
crates/wasi/src/runtime/task.rs | 53 +++++++++++++++++++++++++++++++++
1 file changed, 53 insertions(+)
create mode 100644 crates/wasi/src/runtime/task.rs
diff --git a/crates/wasi/src/runtime/task.rs b/crates/wasi/src/runtime/task.rs
new file mode 100644
index 000000000000..389998b2f6ce
--- /dev/null
+++ b/crates/wasi/src/runtime/task.rs
@@ -0,0 +1,53 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+/// Exactly like a [`tokio::task::JoinHandle`], except that it aborts the task when
+/// the handle is dropped.
+///
+/// This behavior makes it easier to tie a worker task to the lifetime of a Resource
+/// by keeping this handle owned by the Resource.
+#[derive(Debug)]
+pub struct AbortOnDropJoinHandle(tokio::task::JoinHandle);
+impl AbortOnDropJoinHandle {
+ /// Abort the task and wait for it to finish. Optionally returns the result
+ /// of the task if it ran to completion prior to being aborted.
+ pub(crate) async fn cancel(mut self) -> Option {
+ self.0.abort();
+
+ match (&mut self.0).await {
+ Ok(value) => Some(value),
+ Err(err) if err.is_cancelled() => None,
+ Err(err) => std::panic::resume_unwind(err.into_panic()),
+ }
+ }
+}
+impl Drop for AbortOnDropJoinHandle {
+ fn drop(&mut self) {
+ self.0.abort()
+ }
+}
+impl std::ops::Deref for AbortOnDropJoinHandle {
+ type Target = tokio::task::JoinHandle;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+impl std::ops::DerefMut for AbortOnDropJoinHandle {
+ fn deref_mut(&mut self) -> &mut tokio::task::JoinHandle {
+ &mut self.0
+ }
+}
+impl From> for AbortOnDropJoinHandle {
+ fn from(jh: tokio::task::JoinHandle) -> Self {
+ AbortOnDropJoinHandle(jh)
+ }
+}
+impl Future for AbortOnDropJoinHandle {
+ type Output = T;
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ match Pin::new(&mut self.as_mut().0).poll(cx) {
+ Poll::Pending => Poll::Pending,
+ Poll::Ready(r) => Poll::Ready(r.expect("child task panicked")),
+ }
+ }
+}