diff --git a/Cargo.lock b/Cargo.lock index f64a6534..b6a54500 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3723,7 +3723,7 @@ dependencies = [ [[package]] name = "wrpc-runtime-wasmtime" -version = "0.17.3" +version = "0.17.4" dependencies = [ "anyhow", "bytes", diff --git a/crates/runtime-wasmtime/Cargo.toml b/crates/runtime-wasmtime/Cargo.toml index 36b42651..af27313a 100644 --- a/crates/runtime-wasmtime/Cargo.toml +++ b/crates/runtime-wasmtime/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wrpc-runtime-wasmtime" -version = "0.17.3" +version = "0.17.4" description = "wRPC wasmtime integration" authors.workspace = true diff --git a/crates/runtime-wasmtime/src/lib.rs b/crates/runtime-wasmtime/src/lib.rs index 020efeb6..96a181ee 100644 --- a/crates/runtime-wasmtime/src/lib.rs +++ b/crates/runtime-wasmtime/src/lib.rs @@ -14,6 +14,7 @@ use futures::future::try_join_all; use futures::stream::FuturesUnordered; use futures::{Stream, TryStreamExt as _}; use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; +use tokio::sync::Mutex; use tokio::try_join; use tokio_util::codec::{Encoder, FramedRead}; use tokio_util::compat::FuturesAsyncReadCompatExt as _; @@ -991,7 +992,7 @@ where pub trait ServeExt: wrpc_transport::Serve { fn serve_function( &self, - store: impl AsContextMut + Send + Clone + 'static, + store: impl Into>>>, func: Func, instance: &str, name: &str, @@ -1001,26 +1002,32 @@ pub trait ServeExt: wrpc_transport::Serve { where T: WasiView + 'static, { + let store = store.into(); async move { - let params_ty: Arc<[_]> = Arc::from(func.params(&store)); - let results_ty: Arc<[_]> = Arc::from(func.results(&store)); + let (params_ty, results_ty): (Arc<[_]>, Arc<[_]>) = { + let store = store.lock().await; + ( + Arc::from(func.params(&*store)), + Arc::from(func.results(&*store)), + ) + }; // TODO: set paths let invocations = self.serve(instance, name, []).await?; - Ok(invocations.and_then(move |(cx, mut tx, rx)| { - let mut store = store.clone(); + let store = store.clone(); let params_ty = Arc::clone(¶ms_ty); let results_ty = Arc::clone(&results_ty); async move { + let mut store = store.lock().await; let mut params = vec![Val::Bool(false); params_ty.len()]; let mut rx = pin!(rx); for (i, (v, ty)) in zip(&mut params, params_ty.iter()).enumerate() { - read_value(&mut store, &mut rx, v, ty, &[i]) + read_value(&mut *store, &mut rx, v, ty, &[i]) .await .with_context(|| format!("failed to decode parameter value {i}"))?; } let mut results = vec![Val::Bool(false); results_ty.len()]; - func.call_async(&mut store, ¶ms, &mut results) + func.call_async(&mut *store, ¶ms, &mut results) .await .context("failed to call function")?; let mut buf = BytesMut::default(); @@ -1047,7 +1054,7 @@ pub trait ServeExt: wrpc_transport::Serve { }), ) .await?; - func.post_return_async(&mut store) + func.post_return_async(&mut *store) .await .context("failed to perform post-return cleanup")?; Ok(cx) @@ -1058,3 +1065,24 @@ pub trait ServeExt: wrpc_transport::Serve { } impl ServeExt for T {} + +#[cfg(test)] +mod tests { + use tokio::sync::Mutex; + + use super::*; + + #[allow(unused)] + async fn serve_function( + srv: impl wrpc_transport::Serve, + store: wasmtime::Store, + func: Func, + ) { + srv.serve_function(Mutex::new(store), func, "foo", "bar") + .await + .unwrap() + .try_for_each(|_| async { Ok(()) }) + .await + .unwrap() + } +}