From 182ee1bead7f4ab331869b6267142f708ca8e7b0 Mon Sep 17 00:00:00 2001 From: Eric Tu Date: Fri, 13 Sep 2024 14:51:25 -0400 Subject: [PATCH] chunk as we pull off stream --- src/wallet.rs | 80 +++++++++++++++++++++++--------------------------- tests/tests.rs | 1 - 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/src/wallet.rs b/src/wallet.rs index 4b56572..dbd872f 100644 --- a/src/wallet.rs +++ b/src/wallet.rs @@ -7,9 +7,14 @@ use secrecy::{ExposeSecret, SecretVec, Zeroize}; use tonic::{ client::GrpcService, codegen::{Body, Bytes, StdError}, - Status, + Status, Streaming, }; +use crate::error::Error; +use crate::BlockRange; +use rayon::iter::IntoParallelIterator; +use rayon::iter::IntoParallelRefIterator; +use rayon::iter::ParallelIterator; use zcash_address::ZcashAddress; use zcash_client_backend::data_api::wallet::{ create_proposed_transactions, input_selection::GreedyInputSelector, propose_transfer, @@ -34,9 +39,6 @@ use zcash_primitives::transaction::components::amount::NonNegativeAmount; use zcash_primitives::transaction::fees::zip317::FeeRule; use zcash_primitives::transaction::TxId; use zcash_proofs::prover::LocalTxProver; - -use crate::error::Error; -use crate::BlockRange; const BATCH_SIZE: u32 = 10000; /// The maximum number of checkpoints to store in each shard-tree @@ -45,6 +47,8 @@ const PRUNING_DEPTH: usize = 100; type Proposal = zcash_client_backend::proposal::Proposal; +fn is_sync() {} +fn is_send() {} /// # A Zcash wallet /// /// A wallet is a set of accounts that can be synchronized together with the blockchain. @@ -171,25 +175,7 @@ where // TODO: Ensure wallet's view of the chain tip as of the previous wallet session is valid. // See https://github.com/Electric-Coin-Company/zec-sqlite-cli/blob/8c2e49f6d3067ec6cc85248488915278c3cb1c5a/src/commands/sync.rs#L157 - // Download and process all blocks in the requested ranges - // Split each range into BATCH_SIZE chunks to avoid requesting too many blocks at once - for scan_range in scan_ranges.into_iter().flat_map(|r| { - // Limit the number of blocks we download and scan at any one time. - (0..).scan(r, |acc, _| { - if acc.is_empty() { - None - } else if let Some((cur, next)) = acc.split_at(acc.block_range().start + BATCH_SIZE) - { - *acc = next; - Some(cur) - } else { - let cur = acc.clone(); - let end = acc.block_range().end; - *acc = ScanRange::from_parts(end..end, acc.priority()); - Some(cur) - } - }) - }) { + for scan_range in scan_ranges { self.fetch_and_scan_range( scan_range.block_range().start.into(), scan_range.block_range().end.into(), @@ -205,7 +191,7 @@ where &mut self, start: u32, end: u32, - ) -> Result>, Error> { + ) -> Result, Error> { let range = service::BlockRange { start: Some(service::BlockId { height: start.into(), @@ -217,9 +203,7 @@ where }), }; - let blocks = self.client.get_block_range(range).await?.into_inner(); - - Ok(blocks) + Ok(self.client.get_block_range(range).await?.into_inner()) } /// Download and process all blocks in the given range @@ -244,24 +228,34 @@ where self.db.get_orchard_nullifiers(NullifierQuery::Unspent)?, ); - tracing::info!("Scanning block range: {:?} to {:?}", start, end); - - let scanned_blocks = self + let mut chunked_block_stream = self .fetch_blocks(start, end) .await? - .map(|compact_block| { - scan_block( - &self.network, - compact_block.unwrap(), - &scanning_keys, - &nullifiers, - None, - ) - }) - .try_collect() - .await?; - - self.db.put_blocks(&chainstate, scanned_blocks)?; + .try_chunks(BATCH_SIZE as usize); + + while let Ok(Some(blocks)) = chunked_block_stream.try_next().await { + tracing::info!( + "Scanning block range: {:?} to {:?}", + blocks.first().unwrap().height, + blocks.last().unwrap().height + ); + + self.db.put_blocks( + &chainstate, + blocks + .into_iter() + .map(|compact_block| { + scan_block( + &self.network, + compact_block, + &scanning_keys, + &nullifiers, + None, + ) + }) + .collect::, _>>()?, + )?; + } Ok(()) } diff --git a/tests/tests.rs b/tests/tests.rs index 54aa372..4765a8f 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -79,7 +79,6 @@ async fn test_get_and_scan_range() { #[tokio::test] async fn test_get_and_scan_range_native() { initialize(); - rayon::spawn(|| { let num_parallel = rayon::current_num_threads(); tracing::info!("Native rayon has {} threads", num_parallel);