From 2121409fba0a1ae7ebfc007f1249f548e5af36e2 Mon Sep 17 00:00:00 2001 From: Tommy Volk Date: Thu, 3 Oct 2024 08:18:20 -0500 Subject: [PATCH] refactor: Wallet owns its update stream --- Cargo.lock | 1 + Cargo.toml | 1 + src/app.rs | 4 +- src/fedimint.rs | 131 ++++++++++++++++++++++++++++++++++++------------ 4 files changed, 103 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 938ead6..a39733e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3270,6 +3270,7 @@ dependencies = [ "secp256k1 0.29.1", "tempfile", "tokio", + "tokio-stream", "tracing-subscriber", ] diff --git a/Cargo.toml b/Cargo.toml index c0a80a4..6ead152 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ nostr-sdk = "0.35.0" palette = "0.7.6" secp256k1 = { version = "0.29.1", features = ["global-context"] } tokio = "1.40.0" +tokio-stream = "0.1.16" tracing-subscriber = "0.3.18" [dev-dependencies] diff --git a/src/app.rs b/src/app.rs index b298143..5378d6f 100644 --- a/src/app.rs +++ b/src/app.rs @@ -185,9 +185,7 @@ impl App { // outer `stream!` is created on every update, but will only be polled if the subscription // ID is new. async_stream::stream! { - let mut stream = wallet - .get_update_stream() - .map(Message::UpdateWalletView); + let mut stream = wallet.get_update_stream().map(Message::UpdateWalletView); while let Some(msg) = stream.next().await { yield msg; diff --git a/src/fedimint.rs b/src/fedimint.rs index 0f2e3dd..7064488 100644 --- a/src/fedimint.rs +++ b/src/fedimint.rs @@ -1,9 +1,9 @@ -use std::fmt::Display; -use std::pin::Pin; use std::{ collections::{BTreeMap, HashMap}, + fmt::Display, path::PathBuf, sync::Arc, + time::Duration, }; use directories::ProjectDirs; @@ -15,10 +15,6 @@ use fedimint_core::{config::FederationId, db::Database, invite_code::InviteCode, use fedimint_ln_client::{LightningClientModule, LnReceiveState}; use fedimint_ln_common::{LightningGateway, LightningGatewayAnnouncement}; use fedimint_rocksdb::RocksDb; -use iced::futures::{ - lock::{Mutex, MutexGuard}, - StreamExt, -}; use lightning_invoice::{Bolt11Invoice, Bolt11InvoiceDescription, Description}; use nostr_sdk::{ bip39::Mnemonic, @@ -29,15 +25,20 @@ use nostr_sdk::{ }, }; use secp256k1::rand::{seq::SliceRandom, thread_rng}; +use tokio::sync::{mpsc, oneshot, watch, Mutex, MutexGuard}; +use tokio_stream::StreamExt; use crate::util::format_amount; const FEDIMINT_CLIENTS_DATA_DIR_NAME: &str = "fedimint_clients"; + // TODO: Figure out if we even want this. If we do, it probably shouldn't live here. // It'd make more sense for it to live wherever the key is maintained elsewhere, and // have `Wallet::new()` assume that the key is already derived. const FEDIMINT_DERIVATION_NUMBER: u32 = 1; +const WALLET_VIEW_UPDATE_INTERVAL: Duration = Duration::from_secs(5); + pub enum LightningReceiveCompletion { Success, Failure, @@ -73,45 +74,103 @@ pub struct Wallet { derivable_secret: DerivableSecret, clients: Arc>>, fedimint_clients_data_dir: PathBuf, + view_update_receiver: watch::Receiver, + // Used to tell `Self.view_update_task` to immediately update the view. + // If the view has changed, the task will yield a new view message. + // Then the oneshot sender is used to tell the caller that the view + // is now up to date (even if no new value was yielded). + force_update_view_sender: mpsc::Sender>, + view_update_task: tokio::task::JoinHandle<()>, +} + +impl Drop for Wallet { + fn drop(&mut self) { + // TODO: We should properly shut down the task rather than aborting it. + self.view_update_task.abort(); + } } impl Wallet { pub fn new(xprivkey: Xpriv, network: Network, project_dirs: &ProjectDirs) -> Self { - Self { - derivable_secret: get_derivable_secret(&xprivkey, network), - clients: Arc::new(Mutex::new(HashMap::new())), - fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME), - } - } + let (view_update_sender, view_update_receiver) = watch::channel(WalletView { + federations: BTreeMap::new(), + }); - // TODO: Optimize this. Repeated polling is not ideal. - pub fn get_update_stream( - &self, - ) -> Pin + Send>> { - let clients = self.clients.clone(); - Box::pin(async_stream::stream! { + let (force_update_view_sender, mut force_update_view_receiver) = + mpsc::channel::>(100); + + let clients = Arc::new(Mutex::new(HashMap::new())); + + let clients_clone = clients.clone(); + let view_update_task = tokio::spawn(async move { let mut last_state_or = None; + + // TODO: Optimize this. Repeated polling is not ideal. loop { - let current_state = Self::get_current_state(clients.lock().await).await; + // Wait either for a force update or for a timeout. If a force update + // occurs, then `force_update_completed_oneshot_or` will be `Some`. + // If a timeout occurs, then `force_update_completed_oneshot_or` will be `None`. + let force_update_completed_oneshot_or = tokio::select! { + Some(force_update_completed_oneshot) = force_update_view_receiver.recv() => Some(force_update_completed_oneshot), + () = tokio::time::sleep(WALLET_VIEW_UPDATE_INTERVAL) => None, + }; + + let current_state = Self::get_current_state(clients_clone.lock().await).await; // Ignoring clippy lint here since the `match` provides better clarity. #[allow(clippy::option_if_let_else)] let has_changed = match &last_state_or { - Some(last_state) => { - ¤t_state != last_state - } + Some(last_state) => ¤t_state != last_state, // If there was no last state, the state has changed. None => true, }; if has_changed { last_state_or = Some(current_state.clone()); - yield current_state; + + // If all receivers have been dropped, stop the task. + if view_update_sender.send(current_state).is_err() { + break; + } } - tokio::time::sleep(std::time::Duration::from_secs(1)).await; + // If this iteration was triggered by a force update, then send a message + // back to the caller to indicate that the view is now up to date. + if let Some(force_update_completed_oneshot) = force_update_completed_oneshot_or { + let _ = force_update_completed_oneshot.send(()); + } } - }) + }); + + Self { + derivable_secret: get_derivable_secret(&xprivkey, network), + clients, + fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME), + view_update_receiver, + force_update_view_sender, + view_update_task, + } + } + + pub fn get_update_stream(&self) -> tokio_stream::wrappers::WatchStream { + tokio_stream::wrappers::WatchStream::new(self.view_update_receiver.clone()) + } + + /// Tell `view_update_task` to update the view, and wait for it to complete. + /// This ensures any streams opened by `get_update_stream` have yielded the + /// latest view. This function should be called at the end of any function + /// that modifies the view. + /// + /// Note: This function takes a `MutexGuard` to ensure that the lock isn't + /// held while waiting for the view to update, which could cause a deadlock. + async fn force_update_view( + &self, + clients: MutexGuard<'_, HashMap>, + ) { + drop(clients); + let (sender, receiver) = oneshot::channel(); + let _ = self.force_update_view_sender.send(sender).await; + let _ = receiver.await; } pub async fn connect_to_joined_federations(&self) -> anyhow::Result<()> { @@ -151,6 +210,8 @@ impl Wallet { clients.insert(federation_id, client); } + self.force_update_view(clients).await; + Ok(()) } @@ -176,9 +237,17 @@ impl Wallet { clients.insert(federation_id, client); + self.force_update_view(clients).await; + Ok(()) } + /// Constructs the current view of the wallet. + /// SHOULD ONLY BE CALLED FROM THE `view_update_task`. + /// This way, `view_update_task` can only yield values + /// when the view is changed, with the guarantee that + /// the view hasn't been updated elsewhere in a way that + /// could de-sync the view. async fn get_current_state( clients: MutexGuard<'_, HashMap>, ) -> WalletView { @@ -230,6 +299,8 @@ impl Wallet { .wait_for_ln_payment(payment_info.payment_type, payment_info.contract_id, false) .await?; + self.force_update_view(clients).await; + Ok(()) } @@ -238,10 +309,7 @@ impl Wallet { federation_id: FederationId, amount: Amount, description: String, - ) -> anyhow::Result<( - Bolt11Invoice, - iced::futures::channel::oneshot::Receiver, - )> { + ) -> anyhow::Result<(Bolt11Invoice, oneshot::Receiver)> { let clients = self.clients.lock().await; let client = clients @@ -267,8 +335,7 @@ impl Wallet { .await? .into_stream(); - let (payment_completion_sender, payment_completion_receiver) = - iced::futures::channel::oneshot::channel(); + let (payment_completion_sender, payment_completion_receiver) = oneshot::channel(); tokio::spawn(async move { while let Some(update) = update_stream.next().await { @@ -288,6 +355,8 @@ impl Wallet { } }); + self.force_update_view(clients).await; + Ok((invoice, payment_completion_receiver)) }