Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Wallet owns its update stream #51

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ nostr-sdk = "0.35.0"
palette = "0.7.6"
secp256k1 = { version = "0.29.1", features = ["global-context"] }
tokio = "1.40.0"
tokio-stream = "0.1.16"
tracing-subscriber = "0.3.18"

[dev-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ impl App {
// outer `stream!` is created on every update, but will only be polled if the subscription
// ID is new.
async_stream::stream! {
let mut stream = wallet
.get_update_stream()
.map(Message::UpdateWalletView);
let mut stream = wallet.get_update_stream().map(Message::UpdateWalletView);

while let Some(msg) = stream.next().await {
yield msg;
Expand Down
131 changes: 100 additions & 31 deletions src/fedimint.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::fmt::Display;
use std::pin::Pin;
use std::{
collections::{BTreeMap, HashMap},
fmt::Display,
path::PathBuf,
sync::Arc,
time::Duration,
};

use directories::ProjectDirs;
Expand All @@ -15,10 +15,6 @@ use fedimint_core::{config::FederationId, db::Database, invite_code::InviteCode,
use fedimint_ln_client::{LightningClientModule, LnReceiveState};
use fedimint_ln_common::{LightningGateway, LightningGatewayAnnouncement};
use fedimint_rocksdb::RocksDb;
use iced::futures::{
lock::{Mutex, MutexGuard},
StreamExt,
};
use lightning_invoice::{Bolt11Invoice, Bolt11InvoiceDescription, Description};
use nostr_sdk::{
bip39::Mnemonic,
Expand All @@ -29,15 +25,20 @@ use nostr_sdk::{
},
};
use secp256k1::rand::{seq::SliceRandom, thread_rng};
use tokio::sync::{mpsc, oneshot, watch, Mutex, MutexGuard};
use tokio_stream::StreamExt;

use crate::util::format_amount;

const FEDIMINT_CLIENTS_DATA_DIR_NAME: &str = "fedimint_clients";

// TODO: Figure out if we even want this. If we do, it probably shouldn't live here.
// It'd make more sense for it to live wherever the key is maintained elsewhere, and
// have `Wallet::new()` assume that the key is already derived.
const FEDIMINT_DERIVATION_NUMBER: u32 = 1;

const WALLET_VIEW_UPDATE_INTERVAL: Duration = Duration::from_secs(5);

pub enum LightningReceiveCompletion {
Success,
Failure,
Expand Down Expand Up @@ -73,45 +74,103 @@ pub struct Wallet {
derivable_secret: DerivableSecret,
clients: Arc<Mutex<HashMap<FederationId, ClientHandle>>>,
fedimint_clients_data_dir: PathBuf,
view_update_receiver: watch::Receiver<WalletView>,
// Used to tell `Self.view_update_task` to immediately update the view.
// If the view has changed, the task will yield a new view message.
// Then the oneshot sender is used to tell the caller that the view
// is now up to date (even if no new value was yielded).
force_update_view_sender: mpsc::Sender<oneshot::Sender<()>>,
view_update_task: tokio::task::JoinHandle<()>,
}

impl Drop for Wallet {
fn drop(&mut self) {
// TODO: We should properly shut down the task rather than aborting it.
self.view_update_task.abort();
}
}

impl Wallet {
pub fn new(xprivkey: Xpriv, network: Network, project_dirs: &ProjectDirs) -> Self {
Self {
derivable_secret: get_derivable_secret(&xprivkey, network),
clients: Arc::new(Mutex::new(HashMap::new())),
fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME),
}
}
let (view_update_sender, view_update_receiver) = watch::channel(WalletView {
federations: BTreeMap::new(),
});

// TODO: Optimize this. Repeated polling is not ideal.
pub fn get_update_stream(
&self,
) -> Pin<Box<dyn iced::futures::Stream<Item = WalletView> + Send>> {
let clients = self.clients.clone();
Box::pin(async_stream::stream! {
let (force_update_view_sender, mut force_update_view_receiver) =
mpsc::channel::<oneshot::Sender<()>>(100);

let clients = Arc::new(Mutex::new(HashMap::new()));

let clients_clone = clients.clone();
let view_update_task = tokio::spawn(async move {
let mut last_state_or = None;

// TODO: Optimize this. Repeated polling is not ideal.
loop {
let current_state = Self::get_current_state(clients.lock().await).await;
// Wait either for a force update or for a timeout. If a force update
// occurs, then `force_update_completed_oneshot_or` will be `Some`.
// If a timeout occurs, then `force_update_completed_oneshot_or` will be `None`.
let force_update_completed_oneshot_or = tokio::select! {
Some(force_update_completed_oneshot) = force_update_view_receiver.recv() => Some(force_update_completed_oneshot),
() = tokio::time::sleep(WALLET_VIEW_UPDATE_INTERVAL) => None,
};

let current_state = Self::get_current_state(clients_clone.lock().await).await;

// Ignoring clippy lint here since the `match` provides better clarity.
#[allow(clippy::option_if_let_else)]
let has_changed = match &last_state_or {
Some(last_state) => {
&current_state != last_state
}
Some(last_state) => &current_state != last_state,
// If there was no last state, the state has changed.
None => true,
};

if has_changed {
last_state_or = Some(current_state.clone());
yield current_state;

// If all receivers have been dropped, stop the task.
if view_update_sender.send(current_state).is_err() {
break;
}
}

tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// If this iteration was triggered by a force update, then send a message
// back to the caller to indicate that the view is now up to date.
if let Some(force_update_completed_oneshot) = force_update_completed_oneshot_or {
let _ = force_update_completed_oneshot.send(());
}
}
})
});

Self {
derivable_secret: get_derivable_secret(&xprivkey, network),
clients,
fedimint_clients_data_dir: project_dirs.data_dir().join(FEDIMINT_CLIENTS_DATA_DIR_NAME),
view_update_receiver,
force_update_view_sender,
view_update_task,
}
}

pub fn get_update_stream(&self) -> tokio_stream::wrappers::WatchStream<WalletView> {
tokio_stream::wrappers::WatchStream::new(self.view_update_receiver.clone())
}

/// Tell `view_update_task` to update the view, and wait for it to complete.
/// This ensures any streams opened by `get_update_stream` have yielded the
/// latest view. This function should be called at the end of any function
/// that modifies the view.
///
/// Note: This function takes a `MutexGuard` to ensure that the lock isn't
/// held while waiting for the view to update, which could cause a deadlock.
async fn force_update_view(
&self,
clients: MutexGuard<'_, HashMap<FederationId, ClientHandle>>,
) {
drop(clients);
let (sender, receiver) = oneshot::channel();
let _ = self.force_update_view_sender.send(sender).await;
let _ = receiver.await;
}

pub async fn connect_to_joined_federations(&self) -> anyhow::Result<()> {
Expand Down Expand Up @@ -151,6 +210,8 @@ impl Wallet {
clients.insert(federation_id, client);
}

self.force_update_view(clients).await;

Ok(())
}

Expand All @@ -176,9 +237,17 @@ impl Wallet {

clients.insert(federation_id, client);

self.force_update_view(clients).await;

Ok(())
}

/// Constructs the current view of the wallet.
/// SHOULD ONLY BE CALLED FROM THE `view_update_task`.
/// This way, `view_update_task` can only yield values
/// when the view is changed, with the guarantee that
/// the view hasn't been updated elsewhere in a way that
/// could de-sync the view.
async fn get_current_state(
clients: MutexGuard<'_, HashMap<FederationId, ClientHandle>>,
) -> WalletView {
Expand Down Expand Up @@ -230,6 +299,8 @@ impl Wallet {
.wait_for_ln_payment(payment_info.payment_type, payment_info.contract_id, false)
.await?;

self.force_update_view(clients).await;

Ok(())
}

Expand All @@ -238,10 +309,7 @@ impl Wallet {
federation_id: FederationId,
amount: Amount,
description: String,
) -> anyhow::Result<(
Bolt11Invoice,
iced::futures::channel::oneshot::Receiver<LightningReceiveCompletion>,
)> {
) -> anyhow::Result<(Bolt11Invoice, oneshot::Receiver<LightningReceiveCompletion>)> {
let clients = self.clients.lock().await;

let client = clients
Expand All @@ -267,8 +335,7 @@ impl Wallet {
.await?
.into_stream();

let (payment_completion_sender, payment_completion_receiver) =
iced::futures::channel::oneshot::channel();
let (payment_completion_sender, payment_completion_receiver) = oneshot::channel();

tokio::spawn(async move {
while let Some(update) = update_stream.next().await {
Expand All @@ -288,6 +355,8 @@ impl Wallet {
}
});

self.force_update_view(clients).await;

Ok((invoice, payment_completion_receiver))
}

Expand Down
Loading