Skip to content

Commit

Permalink
refactor: Wallet owns its update stream
Browse files Browse the repository at this point in the history
  • Loading branch information
tvolk131 committed Oct 3, 2024
1 parent 02875cc commit 2121409
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ nostr-sdk = "0.35.0"
palette = "0.7.6"
secp256k1 = { version = "0.29.1", features = ["global-context"] }
tokio = "1.40.0"
tokio-stream = "0.1.16"
tracing-subscriber = "0.3.18"

[dev-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ impl App {
// outer `stream!` is created on every update, but will only be polled if the subscription
// ID is new.
async_stream::stream! {
let mut stream = wallet
.get_update_stream()
.map(Message::UpdateWalletView);
let mut stream = wallet.get_update_stream().map(Message::UpdateWalletView);

while let Some(msg) = stream.next().await {
yield msg;
Expand Down
131 changes: 100 additions & 31 deletions src/fedimint.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::fmt::Display;
use std::pin::Pin;
use std::{
collections::{BTreeMap, HashMap},
fmt::Display,
path::PathBuf,
sync::Arc,
time::Duration,
};

use directories::ProjectDirs;
Expand All @@ -15,10 +15,6 @@ use fedimint_core::{config::FederationId, db::Database, invite_code::InviteCode,
use fedimint_ln_client::{LightningClientModule, LnReceiveState};
use fedimint_ln_common::{LightningGateway, LightningGatewayAnnouncement};
use fedimint_rocksdb::RocksDb;
use iced::futures::{
lock::{Mutex, MutexGuard},
StreamExt,
};
use lightning_invoice::{Bolt11Invoice, Bolt11InvoiceDescription, Description};
use nostr_sdk::{
bip39::Mnemonic,
Expand All @@ -29,15 +25,20 @@ use nostr_sdk::{
},
};
use secp256k1::rand::{seq::SliceRandom, thread_rng};
use tokio::sync::{mpsc, oneshot, watch, Mutex, MutexGuard};
use tokio_stream::StreamExt;

use crate::util::format_amount;

const FEDIMINT_CLIENTS_DATA_DIR_NAME: &str = "fedimint_clients";

// TODO: Figure out if we even want this. If we do, it probably shouldn't live here.
// It'd make more sense for it to live wherever the key is maintained elsewhere, and
// have `Wallet::new()` assume that the key is already derived.
const FEDIMINT_DERIVATION_NUMBER: u32 = 1;

const WALLET_VIEW_UPDATE_INTERVAL: Duration = Duration::from_secs(5);

pub enum LightningReceiveCompletion {
Success,
Failure,
Expand Down Expand Up @@ -73,45 +74,103 @@ pub struct Wallet {
derivable_secret: DerivableSecret,
clients: Arc<Mutex<HashMap<FederationId, ClientHandle>>>,
fedimint_clients_data_dir: PathBuf,
view_update_receiver: watch::Receiver<WalletView>,
// Used to tell `Self.view_update_task` to immediately update the view.
// If the view has changed, the task will yield a new view message.
// Then the oneshot sender is used to tell the caller that the view
// is now up to date (even if no new value was yielded).
force_update_view_sender: mpsc::Sender<oneshot::Sender<()>>,
view_update_task: tokio::task::JoinHandle<()>,
}

impl Drop for Wallet {
fn drop(&mut self) {
// TODO: We should properly shut down the task rather than aborting it.
self.view_update_task.abort();
}
}

impl Wallet {
pub fn new(xprivkey: Xpriv, network: Network, project_dirs: &ProjectDirs) -> Self {
Self {
derivable_secret: get_derivable_secret(&xprivkey, network),
clients: Arc::new(Mutex::new(HashMap::new())),
fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME),
}
}
let (view_update_sender, view_update_receiver) = watch::channel(WalletView {
federations: BTreeMap::new(),
});

// TODO: Optimize this. Repeated polling is not ideal.
pub fn get_update_stream(
&self,
) -> Pin<Box<dyn iced::futures::Stream<Item = WalletView> + Send>> {
let clients = self.clients.clone();
Box::pin(async_stream::stream! {
let (force_update_view_sender, mut force_update_view_receiver) =
mpsc::channel::<oneshot::Sender<()>>(100);

let clients = Arc::new(Mutex::new(HashMap::new()));

let clients_clone = clients.clone();
let view_update_task = tokio::spawn(async move {
let mut last_state_or = None;

// TODO: Optimize this. Repeated polling is not ideal.
loop {
let current_state = Self::get_current_state(clients.lock().await).await;
// Wait either for a force update or for a timeout. If a force update
// occurs, then `force_update_completed_oneshot_or` will be `Some`.
// If a timeout occurs, then `force_update_completed_oneshot_or` will be `None`.
let force_update_completed_oneshot_or = tokio::select! {
Some(force_update_completed_oneshot) = force_update_view_receiver.recv() => Some(force_update_completed_oneshot),
() = tokio::time::sleep(WALLET_VIEW_UPDATE_INTERVAL) => None,
};

let current_state = Self::get_current_state(clients_clone.lock().await).await;

// Ignoring clippy lint here since the `match` provides better clarity.
#[allow(clippy::option_if_let_else)]
let has_changed = match &last_state_or {
Some(last_state) => {
&current_state != last_state
}
Some(last_state) => &current_state != last_state,
// If there was no last state, the state has changed.
None => true,
};

if has_changed {
last_state_or = Some(current_state.clone());
yield current_state;

// If all receivers have been dropped, stop the task.
if view_update_sender.send(current_state).is_err() {
break;
}
}

tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// If this iteration was triggered by a force update, then send a message
// back to the caller to indicate that the view is now up to date.
if let Some(force_update_completed_oneshot) = force_update_completed_oneshot_or {
let _ = force_update_completed_oneshot.send(());
}
}
})
});

Self {
derivable_secret: get_derivable_secret(&xprivkey, network),
clients,
fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME),
view_update_receiver,
force_update_view_sender,
view_update_task,
}
}

pub fn get_update_stream(&self) -> tokio_stream::wrappers::WatchStream<WalletView> {
tokio_stream::wrappers::WatchStream::new(self.view_update_receiver.clone())
}

/// Tell `view_update_task` to update the view, and wait for it to complete.
/// This ensures any streams opened by `get_update_stream` have yielded the
/// latest view. This function should be called at the end of any function
/// that modifies the view.
///
/// Note: This function takes a `MutexGuard` to ensure that the lock isn't
/// held while waiting for the view to update, which could cause a deadlock.
async fn force_update_view(
&self,
clients: MutexGuard<'_, HashMap<FederationId, ClientHandle>>,
) {
drop(clients);
let (sender, receiver) = oneshot::channel();
let _ = self.force_update_view_sender.send(sender).await;
let _ = receiver.await;
}

pub async fn connect_to_joined_federations(&self) -> anyhow::Result<()> {
Expand Down Expand Up @@ -151,6 +210,8 @@ impl Wallet {
clients.insert(federation_id, client);
}

self.force_update_view(clients).await;

Ok(())
}

Expand All @@ -176,9 +237,17 @@ impl Wallet {

clients.insert(federation_id, client);

self.force_update_view(clients).await;

Ok(())
}

/// Constructs the current view of the wallet.
/// SHOULD ONLY BE CALLED FROM THE `view_update_task`.
/// This way, `view_update_task` can only yield values
/// when the view is changed, with the guarantee that
/// the view hasn't been updated elsewhere in a way that
/// could de-sync the view.
async fn get_current_state(
clients: MutexGuard<'_, HashMap<FederationId, ClientHandle>>,
) -> WalletView {
Expand Down Expand Up @@ -230,6 +299,8 @@ impl Wallet {
.wait_for_ln_payment(payment_info.payment_type, payment_info.contract_id, false)
.await?;

self.force_update_view(clients).await;

Ok(())
}

Expand All @@ -238,10 +309,7 @@ impl Wallet {
federation_id: FederationId,
amount: Amount,
description: String,
) -> anyhow::Result<(
Bolt11Invoice,
iced::futures::channel::oneshot::Receiver<LightningReceiveCompletion>,
)> {
) -> anyhow::Result<(Bolt11Invoice, oneshot::Receiver<LightningReceiveCompletion>)> {
let clients = self.clients.lock().await;

let client = clients
Expand All @@ -267,8 +335,7 @@ impl Wallet {
.await?
.into_stream();

let (payment_completion_sender, payment_completion_receiver) =
iced::futures::channel::oneshot::channel();
let (payment_completion_sender, payment_completion_receiver) = oneshot::channel();

tokio::spawn(async move {
while let Some(update) = update_stream.next().await {
Expand All @@ -288,6 +355,8 @@ impl Wallet {
}
});

self.force_update_view(clients).await;

Ok((invoice, payment_completion_receiver))
}

Expand Down

0 comments on commit 2121409

Please sign in to comment.