Skip to content

Commit

Permalink
Add remove subscriptions messages
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Nov 12, 2024
1 parent c9f0c19 commit 2cf81a1
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ tungstenite = "0.21.0"
native-tls = "0.2.11"
rustls = "0.22.0"
rustls-pemfile = "2.1.3"
log = "0.4.21"
flate2 = "1.0.30"
once_cell = "1.19.0"
num-bigint = "0.4.6"
Expand Down
2 changes: 1 addition & 1 deletion crates/chia-sdk-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ futures-channel = { workspace = true, features = ["sink"] }
futures-util = { workspace = true }
indexmap = { workspace = true }
thiserror = { workspace = true }
log = { workspace = true }
tracing = { workspace = true }
itertools = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
Expand Down
8 changes: 4 additions & 4 deletions crates/chia-sdk-test/src/peer_simulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PeerSimulator {
}

pub async fn with_config(config: SimulatorConfig) -> Result<Self, PeerSimulatorError> {
log::info!("starting simulator");
tracing::info!("starting simulator");

let addr = "127.0.0.1:0";
let peer_map = PeerMap::default();
Expand All @@ -60,7 +60,7 @@ impl PeerSimulator {
let stream = match tokio_tungstenite::accept_async(stream).await {
Ok(stream) => stream,
Err(error) => {
log::error!("error accepting websocket connection: {}", error);
tracing::error!("error accepting websocket connection: {}", error);
continue;
}
};
Expand Down Expand Up @@ -89,7 +89,7 @@ impl PeerSimulator {
}

pub async fn connect_raw(&self) -> Result<(Peer, mpsc::Receiver<Message>), PeerSimulatorError> {
log::info!("connecting new peer to simulator");
tracing::info!("connecting new peer to simulator");
let (ws, _) = connect_async(format!("ws://{}", self.addr)).await?;
Ok(Peer::from_websocket(
ws,
Expand All @@ -115,7 +115,7 @@ impl PeerSimulator {

tokio::spawn(async move {
while let Some(message) = receiver.recv().await {
log::debug!("received message: {message:?}");
tracing::debug!("received message: {message:?}");
}
});

Expand Down
58 changes: 58 additions & 0 deletions crates/chia-sdk-test/src/peer_simulator/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,64 @@ impl Subscriptions {
.extend(puzzle_hashes);
}

pub(crate) fn remove_coin_subscriptions(
&mut self,
peer: SocketAddr,
coin_ids: &[Bytes32],
) -> Vec<Bytes32> {
let mut removed = Vec::new();

if let Some(subscriptions) = self.coin_subscriptions.get_mut(&peer) {
for coin_id in coin_ids {
if subscriptions.swap_remove(coin_id) {
removed.push(*coin_id);
}
}
if subscriptions.is_empty() {
self.coin_subscriptions.swap_remove(&peer);
}
}

removed
}

pub(crate) fn remove_puzzle_subscriptions(
&mut self,
peer: SocketAddr,
puzzle_hashes: &[Bytes32],
) -> Vec<Bytes32> {
let mut removed = Vec::new();

if let Some(subscriptions) = self.puzzle_subscriptions.get_mut(&peer) {
for puzzle_hash in puzzle_hashes {
if subscriptions.swap_remove(puzzle_hash) {
removed.push(*puzzle_hash);
}
}
if subscriptions.is_empty() {
self.puzzle_subscriptions.swap_remove(&peer);
}
}

removed
}

pub(crate) fn remove_all_coin_subscriptions(&mut self, peer: SocketAddr) -> Vec<Bytes32> {
self.coin_subscriptions
.swap_remove(&peer)
.unwrap_or_default()
.into_iter()
.collect()
}

pub(crate) fn remove_all_puzzle_subscriptions(&mut self, peer: SocketAddr) -> Vec<Bytes32> {
self.puzzle_subscriptions
.swap_remove(&peer)
.unwrap_or_default()
.into_iter()
.collect()
}

pub(crate) fn subscription_count(&self, peer: SocketAddr) -> usize {
self.coin_subscriptions.get(&peer).map_or(0, IndexSet::len)
+ self
Expand Down
69 changes: 60 additions & 9 deletions crates/chia-sdk-test/src/peer_simulator/ws_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use chia_protocol::{
Bytes, Bytes32, CoinState, CoinStateUpdate, Message, NewPeakWallet, ProtocolMessageTypes,
PuzzleSolutionResponse, RegisterForCoinUpdates, RegisterForPhUpdates, RejectCoinState,
RejectPuzzleSolution, RejectPuzzleState, RejectStateReason, RequestChildren, RequestCoinState,
RequestPuzzleSolution, RequestPuzzleState, RespondChildren, RespondCoinState,
RespondPuzzleSolution, RespondPuzzleState, RespondToCoinUpdates, RespondToPhUpdates,
SendTransaction, SpendBundle, TransactionAck,
RequestPuzzleSolution, RequestPuzzleState, RequestRemoveCoinSubscriptions,
RequestRemovePuzzleSubscriptions, RespondChildren, RespondCoinState, RespondPuzzleSolution,
RespondPuzzleState, RespondRemoveCoinSubscriptions, RespondRemovePuzzleSubscriptions,
RespondToCoinUpdates, RespondToPhUpdates, SendTransaction, SpendBundle, TransactionAck,
};
use chia_traits::Streamable;
use clvmr::NodePtr;
Expand Down Expand Up @@ -42,7 +43,7 @@ pub(crate) async fn ws_connection(
let (mut tx, mut rx) = mpsc::unbounded();

if let Err(error) = handle_initial_peak(&mut tx, &simulator).await {
log::error!("error sending initial peak: {}", error);
tracing::error!("error sending initial peak: {}", error);
return;
}

Expand All @@ -53,7 +54,7 @@ pub(crate) async fn ws_connection(
tokio::spawn(async move {
while let Some(message) = rx.next().await {
if let Err(error) = sink.send(message).await {
log::error!("error sending message to peer: {}", error);
tracing::error!("error sending message to peer: {}", error);
continue;
}
}
Expand All @@ -63,7 +64,7 @@ pub(crate) async fn ws_connection(
let message = match message {
Ok(message) => message,
Err(error) => {
log::info!("received error from stream: {:?}", error);
tracing::info!("received error from stream: {:?}", error);
break;
}
};
Expand All @@ -79,7 +80,7 @@ pub(crate) async fn ws_connection(
)
.await
{
log::error!("error handling message: {}", error);
tracing::error!("error handling message: {}", error);
break;
}
}
Expand Down Expand Up @@ -166,6 +167,24 @@ async fn handle_message(
let response = request_puzzle_state(addr, request, config, &simulator, subscriptions)?;
(ProtocolMessageTypes::RespondPuzzleState, response)
}
ProtocolMessageTypes::RequestRemoveCoinSubscriptions => {
let request = RequestRemoveCoinSubscriptions::from_bytes(&request.data)?;
let mut subscriptions = subscriptions.lock().await;
let response = request_remove_coin_subscriptions(addr, request, &mut subscriptions)?;
(
ProtocolMessageTypes::RespondRemoveCoinSubscriptions,
response,
)
}
ProtocolMessageTypes::RequestRemovePuzzleSubscriptions => {
let request = RequestRemovePuzzleSubscriptions::from_bytes(&request.data)?;
let mut subscriptions = subscriptions.lock().await;
let response = request_remove_puzzle_subscriptions(addr, request, &mut subscriptions)?;
(
ProtocolMessageTypes::RespondRemovePuzzleSubscriptions,
response,
)
}
message_type => {
return Err(PeerSimulatorError::UnsupportedMessage(message_type));
}
Expand Down Expand Up @@ -250,7 +269,7 @@ async fn send_transaction(
let updates = match new_transaction(&mut simulator, &mut subscriptions, request.transaction) {
Ok(updates) => updates,
Err(error) => {
log::error!("error processing transaction: {:?}", &error);
tracing::error!("error processing transaction: {:?}", &error);

let error_code = match error {
PeerSimulatorError::Simulator(SimulatorError::Validation(error_code)) => error_code,
Expand Down Expand Up @@ -281,7 +300,7 @@ async fn send_transaction(

// Send updates to peers.
for (addr, mut peer) in peer_map.peers().await {
peer.send(new_peak.clone().into()).await.unwrap();
peer.send(new_peak.clone().into()).await?;

let Some(peer_updates) = updates.get(&addr).cloned() else {
continue;
Expand Down Expand Up @@ -558,3 +577,35 @@ fn request_puzzle_state(
.to_bytes()?
.into())
}

fn request_remove_coin_subscriptions(
peer: SocketAddr,
request: RequestRemoveCoinSubscriptions,
subscriptions: &mut MutexGuard<'_, Subscriptions>,
) -> Result<Bytes, PeerSimulatorError> {
let coin_ids = if let Some(coin_ids) = request.coin_ids {
subscriptions.remove_coin_subscriptions(peer, &coin_ids)
} else {
subscriptions.remove_all_coin_subscriptions(peer)
};

Ok(RespondRemoveCoinSubscriptions { coin_ids }
.to_bytes()?
.into())
}

fn request_remove_puzzle_subscriptions(
peer: SocketAddr,
request: RequestRemovePuzzleSubscriptions,
subscriptions: &mut MutexGuard<'_, Subscriptions>,
) -> Result<Bytes, PeerSimulatorError> {
let puzzle_hashes = if let Some(puzzle_hashes) = request.puzzle_hashes {
subscriptions.remove_puzzle_subscriptions(peer, &puzzle_hashes)
} else {
subscriptions.remove_all_puzzle_subscriptions(peer)
};

Ok(RespondRemovePuzzleSubscriptions { puzzle_hashes }
.to_bytes()?
.into())
}

0 comments on commit 2cf81a1

Please sign in to comment.