diff --git a/core/Cargo.toml b/core/Cargo.toml index 7b136f3cab..3566807b1a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -24,7 +24,6 @@ serde_json = { version = "1", features = ["raw_value"] } tracing = "0.1.34" # optional deps -async-lock = { version = "3.0", optional = true } futures-util = { version = "0.3.14", default-features = false, optional = true } hyper = { version = "0.14.10", default-features = false, features = ["stream"], optional = true } rustc-hash = { version = "1", optional = true } @@ -51,7 +50,6 @@ server = [ ] client = ["futures-util/sink", "tokio/sync"] async-client = [ - "async-lock", "client", "futures-util/alloc", "rustc-hash", @@ -63,7 +61,6 @@ async-client = [ "pin-project", ] async-wasm-client = [ - "async-lock", "client", "futures-util/alloc", "wasm-bindgen-futures", diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index a8c7f3d746..1eb7cc88bd 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -52,7 +52,6 @@ use jsonrpsee_types::{InvalidRequestId, ResponseSuccess, TwoPointZero}; use manager::RequestManager; use std::sync::Arc; -use async_lock::RwLock as AsyncRwLock; use async_trait::async_trait; use futures_timer::Delay; use futures_util::future::{self, Either}; @@ -69,6 +68,7 @@ use self::utils::{InactivityCheck, IntervalStream}; use super::{generate_batch_id_range, FrontToBack, IdKind, RequestIdManager}; const LOG_TARGET: &str = "jsonrpsee-client"; +const NOT_POISONED: &str = "Not poisoned; qed"; /// Configuration for WebSocket ping/pong mechanism and it may be used to disconnect /// an inactive connection. @@ -142,67 +142,39 @@ impl ThreadSafeRequestManager { } pub(crate) fn lock(&self) -> std::sync::MutexGuard { - self.0.lock().expect("Not poisoned; qed") + self.0.lock().expect(NOT_POISONED) } } + +pub(crate) type SharedDisconnectReason = Arc>>>; + /// If the background thread is terminated, this type /// can be used to read the error cause. /// // NOTE: This is an AsyncRwLock to be &self. #[derive(Debug)] -struct ErrorFromBack(AsyncRwLock>); +struct ErrorFromBack { + conn: mpsc::Sender, + disconnect_reason: SharedDisconnectReason, +} impl ErrorFromBack { - fn new(unread: oneshot::Receiver) -> Self { - Self(AsyncRwLock::new(Some(ReadErrorOnce::Unread(unread)))) + fn new(conn: mpsc::Sender, disconnect_reason: SharedDisconnectReason) -> Self { + Self { conn, disconnect_reason } } async fn read_error(&self) -> Error { - const PROOF: &str = "Option is only is used to workaround ownership issue and is always Some; qed"; + // When the background task is closed the error is written to `disconnect_reason`. + self.conn.closed().await; - if let ReadErrorOnce::Read(ref err) = self.0.read().await.as_ref().expect(PROOF) { - return Error::RestartNeeded(err.clone()); - }; - - let mut write = self.0.write().await; - let state = write.take(); - - let err = match state.expect(PROOF) { - ReadErrorOnce::Unread(rx) => { - let arc_err = Arc::new(match rx.await { - Ok(err) => err, - // This should never happen because the receiving end is still alive. - // Before shutting down the background task a error message should - // be emitted. - Err(_) => Error::Custom( - "Error reason could not be found. This is a bug. Please open an issue.".to_string(), - ), - }); - *write = Some(ReadErrorOnce::Read(arc_err.clone())); - arc_err - } - ReadErrorOnce::Read(arc_err) => { - *write = Some(ReadErrorOnce::Read(arc_err.clone())); - arc_err - } - }; - - Error::RestartNeeded(err) + if let Some(err) = self.disconnect_reason.read().expect(NOT_POISONED).as_ref() { + Error::RestartNeeded(err.clone()) + } else { + Error::Custom("Error reason could not be found. This is a bug. Please open an issue.".to_string()) + } } } -/// Wrapper over a [`oneshot::Receiver`] that reads -/// the underlying channel once and then stores the result in String. -/// It is possible that the error is read more than once if several calls are made -/// when the background thread has been terminated. -#[derive(Debug)] -enum ReadErrorOnce { - /// Error message is already read. - Read(Arc), - /// Error message is unread. - Unread(oneshot::Receiver), -} - /// Builder for [`Client`]. #[derive(Debug, Copy, Clone)] pub struct ClientBuilder { @@ -318,7 +290,7 @@ impl ClientBuilder { R: TransportReceiverT + Send, { let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); - let (err_to_front, err_from_back) = oneshot::channel::(); + let disconnect_reason = SharedDisconnectReason::default(); let max_buffer_capacity_per_subscription = self.max_buffer_capacity_per_subscription; let (client_dropped_tx, client_dropped_rx) = oneshot::channel(); let (send_receive_task_sync_tx, send_receive_task_sync_rx) = mpsc::channel(1); @@ -366,12 +338,12 @@ impl ClientBuilder { inactivity_stream, })); - tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, err_to_front)); + tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, disconnect_reason.clone())); Client { - to_back, + to_back: to_back.clone(), request_timeout: self.request_timeout, - error: ErrorFromBack::new(err_from_back), + error: ErrorFromBack::new(to_back, disconnect_reason), id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind), max_log_length: self.max_log_length, on_exit: Some(client_dropped_tx), @@ -391,7 +363,7 @@ impl ClientBuilder { type PendingIntervalStream = IntervalStream>; let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); - let (err_to_front, err_from_back) = oneshot::channel::(); + let disconnect_reason = SharedDisconnectReason::default(); let max_buffer_capacity_per_subscription = self.max_buffer_capacity_per_subscription; let (client_dropped_tx, client_dropped_rx) = oneshot::channel(); let (send_receive_task_sync_tx, send_receive_task_sync_rx) = mpsc::channel(1); @@ -423,13 +395,13 @@ impl ClientBuilder { wasm_bindgen_futures::spawn_local(wait_for_shutdown( send_receive_task_sync_rx, client_dropped_rx, - err_to_front, + disconnect_reason.clone(), )); Client { - to_back, + to_back: to_back.clone(), request_timeout: self.request_timeout, - error: ErrorFromBack::new(err_from_back), + error: ErrorFromBack::new(to_back, disconnect_reason), id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind), max_log_length: self.max_log_length, on_exit: Some(client_dropped_tx), @@ -474,7 +446,7 @@ impl Client { /// /// # Cancel-safety /// - /// This method is not cancel-safe + /// This method is cancel-safe pub async fn disconnect_reason(&self) -> Error { self.error.read_error().await } @@ -1070,7 +1042,7 @@ where async fn wait_for_shutdown( mut close_rx: mpsc::Receiver>, client_dropped: oneshot::Receiver<()>, - err_to_front: oneshot::Sender, + err_to_front: SharedDisconnectReason, ) { let rx_item = close_rx.recv(); @@ -1078,6 +1050,6 @@ async fn wait_for_shutdown( // Send an error to the frontend if the send or receive task completed with an error. if let Either::Left((Some(Err(err)), _)) = future::select(rx_item, client_dropped).await { - let _ = err_to_front.send(err); + *err_to_front.write().expect(NOT_POISONED) = Some(Arc::new(err)); } }