diff --git a/node/src/exchange.rs b/node/src/exchange.rs index 27ebacf3..68f417ef 100644 --- a/node/src/exchange.rs +++ b/node/src/exchange.rs @@ -36,9 +36,11 @@ const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024; /// Maximum length of the protobuf length delimiter in bytes const PROTOBUF_MAX_LENGTH_DELIMITER_LEN: usize = 10; +type RequestType = HeaderRequest; +type ResponseType = Vec; type ReqRespBehaviour = request_response::Behaviour; -type ReqRespEvent = request_response::Event>; -type ReqRespMessage = request_response::Message>; +type ReqRespEvent = request_response::Event; +type ReqRespMessage = request_response::Message; pub(crate) struct ExchangeBehaviour where @@ -251,13 +253,20 @@ where cx: &mut Context<'_>, params: &mut impl PollParameters, ) -> Poll>> { - while let Poll::Ready(ev) = self.req_resp.poll(cx, params) { - if let Some(ev) = self.on_to_swarm(ev) { - return Poll::Ready(ev); + loop { + if let Poll::Ready(ev) = self.req_resp.poll(cx, params) { + if let Some(ev) = self.on_to_swarm(ev) { + return Poll::Ready(ev); + } + + continue; + } + if self.server_handler.poll(cx, &mut self.req_resp).is_ready() { + continue; } - } - Poll::Pending + return Poll::Pending; + } } } diff --git a/node/src/exchange/client.rs b/node/src/exchange/client.rs index 6f11f6d5..a56d806f 100644 --- a/node/src/exchange/client.rs +++ b/node/src/exchange/client.rs @@ -22,7 +22,7 @@ use crate::utils::{OneshotResultSender, OneshotResultSenderExt}; pub(super) struct ExchangeClientHandler where - S: Sender, + S: RequestSender, { reqs: HashMap, peer_tracker: Arc, @@ -33,13 +33,13 @@ struct State { respond_to: OneshotResultSender, P2pError>, } -pub(super) trait Sender { +pub(super) trait RequestSender { type RequestId: Hash + Eq + Debug; fn send_request(&mut self, peer: &PeerId, request: HeaderRequest) -> Self::RequestId; } -impl Sender for ReqRespBehaviour { +impl RequestSender for ReqRespBehaviour { type RequestId = RequestId; fn send_request(&mut self, peer: &PeerId, request: HeaderRequest) -> RequestId { @@ -49,7 +49,7 @@ impl Sender for ReqRespBehaviour { impl ExchangeClientHandler where - S: Sender, + S: RequestSender, { pub(super) fn new(peer_tracker: Arc) -> Self { ExchangeClientHandler { @@ -1004,7 +1004,7 @@ mod tests { peer: PeerId, } - impl Sender for MockReq { + impl RequestSender for MockReq { type RequestId = MockReqId; fn send_request(&mut self, peer: &PeerId, _request: HeaderRequest) -> Self::RequestId { diff --git a/node/src/exchange/server.rs b/node/src/exchange/server.rs index 8ee24424..c6c90146 100644 --- a/node/src/exchange/server.rs +++ b/node/src/exchange/server.rs @@ -1,52 +1,394 @@ +use std::fmt::{Debug, Display}; use std::sync::Arc; +use std::task::{Context, Poll}; -use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse}; +use celestia_proto::p2p::pb::{header_request, HeaderRequest, HeaderResponse}; +use celestia_types::hash::Hash; use libp2p::{ request_response::{InboundFailure, RequestId, ResponseChannel}, PeerId, }; -use tracing::instrument; +use tokio::sync::mpsc::{self, error::TrySendError}; +use tracing::{instrument, trace}; +use crate::exchange::utils::{ExtendedHeaderExt, HeaderRequestExt, HeaderResponseExt}; +use crate::exchange::{ReqRespBehaviour, ResponseType}; +use crate::executor::spawn; use crate::store::Store; -pub(super) struct ExchangeServerHandler +const MAX_HEADERS_AMOUNT_RESPONSE: u64 = 512; + +pub(super) struct ExchangeServerHandler where - S: Store + 'static, + S: Store, + R: ResponseSender, { - _store: Arc, + store: Arc, + + rx: mpsc::Receiver<(R::Channel, ResponseType)>, + tx: mpsc::Sender<(R::Channel, ResponseType)>, +} + +pub(super) trait ResponseSender { + type Channel: Send + 'static; + + fn send_response(&mut self, channel: Self::Channel, response: ResponseType); } -impl ExchangeServerHandler +impl ResponseSender for ReqRespBehaviour { + type Channel = ResponseChannel; + + fn send_response(&mut self, channel: Self::Channel, response: ResponseType) { + // response was prepared specifically for the request, we can drop it + // in case of error we'll get Event::InboundFailure + let _ = self.send_response(channel, response); + } +} + +impl ExchangeServerHandler where S: Store + 'static, + R: ResponseSender, { pub(super) fn new(store: Arc) -> Self { - ExchangeServerHandler { _store: store } + let (tx, rx) = mpsc::channel(32); + ExchangeServerHandler { store, rx, tx } } - #[instrument(level = "trace", skip(self, _respond_to))] - pub(super) fn on_request_received( + #[instrument(level = "trace", skip(self, response_channel))] + pub(super) fn on_request_received( &mut self, peer: PeerId, - request_id: RequestId, + request_id: Id, request: HeaderRequest, - _respond_to: ResponseChannel>, - ) { - // TODO + response_channel: R::Channel, + ) where + Id: Display + Debug, + { + let Some((amount, data)) = parse_request(request) else { + self.handle_invalid_request(response_channel); + return; + }; + + match data { + header_request::Data::Origin(0) => { + self.handle_request_current_head(response_channel); + } + header_request::Data::Origin(height) => { + self.handle_request_by_height(response_channel, height, amount); + } + header_request::Data::Hash(hash) => { + self.handle_request_by_hash(response_channel, hash); + } + }; } - #[instrument(level = "trace", skip(self))] pub(super) fn on_response_sent(&mut self, peer: PeerId, request_id: RequestId) { - // TODO + trace!("response_sent; request_id: {request_id}, peer: {peer}"); } - #[instrument(level = "trace", skip(self))] pub(super) fn on_failure( &mut self, peer: PeerId, request_id: RequestId, error: InboundFailure, ) { - // TODO + // TODO: cancel job if libp2p already failed it? + trace!("on_failure; request_id: {request_id}, peer: {peer}, error: {error:?}"); + } + + pub fn poll(&mut self, cx: &mut Context<'_>, sender: &mut R) -> Poll<()> { + loop { + if let Poll::Ready(Some((channel, response))) = self.rx.poll_recv(cx) { + sender.send_response(channel, response); + continue; + } + + return Poll::Pending; + } + } + + fn handle_request_current_head(&mut self, channel: R::Channel) { + let store = self.store.clone(); + let tx = self.tx.clone(); + + spawn(async move { + let response = store + .get_head() + .await + .map(|head| head.to_header_response()) + .unwrap_or_else(|_| HeaderResponse::not_found()); + + let _ = tx.send((channel, vec![response])).await; + }); + } + + fn handle_request_by_hash(&mut self, channel: R::Channel, hash: Vec) { + let Ok(hash) = hash.try_into().map(Hash::Sha256) else { + self.handle_invalid_request(channel); + return; + }; + + let store = self.store.clone(); + let tx = self.tx.clone(); + + spawn(async move { + let response = store + .get_by_hash(&hash) + .await + .map(|head| head.to_header_response()) + .unwrap_or_else(|_| HeaderResponse::not_found()); + + let _ = tx.send((channel, vec![response])).await; + }); + } + + fn handle_request_by_height(&mut self, channel: R::Channel, origin: u64, amount: u64) { + let store = self.store.clone(); + let tx = self.tx.clone(); + + spawn(async move { + let amount = amount.min(MAX_HEADERS_AMOUNT_RESPONSE); + let mut responses = vec![]; + + for i in origin..origin + amount { + match store.get_by_height(i).await { + Ok(h) => { + if responses.is_empty() { + responses.reserve_exact(amount as usize); + } + + responses.push(h.to_header_response()); + } + Err(_) => break, + } + } + + if responses.is_empty() { + responses.reserve_exact(1); + responses.push(HeaderResponse::not_found()); + } + + let _ = tx.send((channel, responses)).await; + }); + } + + fn handle_invalid_request(&self, channel: R::Channel) { + if let Err(TrySendError::Full(response)) = + self.tx.try_send((channel, vec![HeaderResponse::invalid()])) + { + let tx = self.tx.clone(); + + spawn(async move { + let _ = tx.send(response).await; + }); + } + } +} + +fn parse_request(request: HeaderRequest) -> Option<(u64, header_request::Data)> { + if !request.is_valid() { + return None; + } + + let HeaderRequest { + amount, + data: Some(data), + } = request + else { + return None; + }; + + Some((amount, data)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::exchange::utils::HeaderRequestExt; + use crate::store::tests::gen_filled_store; + use crate::store::InMemoryStore; + use celestia_proto::p2p::pb::header_request::Data; + use celestia_proto::p2p::pb::{HeaderRequest, StatusCode}; + use celestia_types::ExtendedHeader; + use libp2p::PeerId; + use std::future::poll_fn; + use std::sync::Arc; + use tendermint_proto::Protobuf; + use tokio::select; + use tokio::sync::oneshot; + + #[tokio::test] + async fn request_head_test() { + let (store, _) = gen_filled_store(4); + let expected_head = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + handler.on_request_received(PeerId::random(), "test", HeaderRequest::head_request(), ()); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_head); + } + + #[tokio::test] + async fn request_header_test() { + let (store, _) = gen_filled_store(3); + let expected_genesis = store.get_by_height(1).unwrap(); + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_origin(1, 1), + (), + ); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_genesis); + } + + #[tokio::test] + async fn invalid_amount_request_test() { + let (store, _) = gen_filled_store(1); + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_origin(0, 0), + (), + ); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn none_data_request_test() { + let (store, _) = gen_filled_store(1); + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + let request = HeaderRequest { + data: None, + amount: 1, + }; + handler.on_request_received(PeerId::random(), "test", request, ()); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn request_hash_test() { + let (store, _) = gen_filled_store(1); + let stored_header = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_hash(stored_header.hash()), + (), + ); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, stored_header); + } + + #[tokio::test] + async fn request_range_test() { + let (store, _) = gen_filled_store(10); + let expected_headers = [ + store.get_by_height(5).unwrap(), + store.get_by_height(6).unwrap(), + store.get_by_height(7).unwrap(), + ]; + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + let request = HeaderRequest { + data: Some(Data::Origin(5)), + amount: u64::try_from(expected_headers.len()).unwrap(), + }; + handler.on_request_received(PeerId::random(), "test", request, ()); + + let received = poll_handler_for_result(&mut handler).await; + + for (rec, exp) in received.iter().zip(expected_headers.iter()) { + assert_eq!(rec.status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp); + } + } + + #[tokio::test] + async fn request_range_beyond_head_test() { + let (store, _) = gen_filled_store(5); + let expected_hashes = [store.get_by_height(5).ok()]; + let expected_status_codes = [StatusCode::Ok]; + assert_eq!(expected_hashes.len(), expected_status_codes.len()); + + let mut handler = ExchangeServerHandler::new(Arc::new(store)); + + let request = HeaderRequest::with_origin(5, 10); + handler.on_request_received(PeerId::random(), "test", request, ()); + + let received = poll_handler_for_result(&mut handler).await; + + assert_eq!(received.len(), expected_hashes.len()); + for (rec, (exp_status, exp_header)) in received + .iter() + .zip(expected_status_codes.iter().zip(expected_hashes.iter())) + { + assert_eq!(rec.status_code, i32::from(*exp_status)); + if let Some(exp_header) = exp_header { + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp_header); + } + } + } + + #[derive(Debug)] + struct TestResponseSender(pub Option>); + + impl ResponseSender for TestResponseSender { + type Channel = (); + + fn send_response(&mut self, _channel: Self::Channel, response: ResponseType) { + if let Some(sender) = self.0.take() { + let _ = sender.send(response); + } + } + } + + // helper which waits for result over the test channel, while continously polling the handler + // needed because `ExchangeServerHandler::poll` never returns `Ready` + async fn poll_handler_for_result( + handler: &mut ExchangeServerHandler, + ) -> Vec { + let (tx, receiver) = oneshot::channel(); + let mut sender = TestResponseSender(Some(tx)); + + let result = select! { + _ = poll_fn(move |cx| handler.poll(cx, &mut sender)) => panic!("shouldn't return"), + r = receiver => { r.unwrap() } + }; + + result } } diff --git a/node/src/exchange/utils.rs b/node/src/exchange/utils.rs index 05a35c3f..d13e80ac 100644 --- a/node/src/exchange/utils.rs +++ b/node/src/exchange/utils.rs @@ -10,6 +10,7 @@ use crate::exchange::ExchangeError; pub(super) trait HeaderRequestExt { fn with_origin(origin: u64, amount: u64) -> HeaderRequest; fn with_hash(hash: Hash) -> HeaderRequest; + fn head_request() -> HeaderRequest; fn is_valid(&self) -> bool; fn is_head_request(&self) -> bool; } @@ -29,6 +30,10 @@ impl HeaderRequestExt for HeaderRequest { } } + fn head_request() -> HeaderRequest { + HeaderRequest::with_origin(0, 1) + } + fn is_valid(&self) -> bool { match (&self.data, self.amount) { (None, _) | (_, 0) => false, @@ -45,6 +50,10 @@ impl HeaderRequestExt for HeaderRequest { pub(super) trait HeaderResponseExt { fn to_extended_header(&self) -> Result; + + fn not_found() -> HeaderResponse; + + fn invalid() -> HeaderResponse; } impl HeaderResponseExt for HeaderResponse { @@ -56,6 +65,20 @@ impl HeaderResponseExt for HeaderResponse { .map_err(|_| ExchangeError::InvalidResponse), } } + + fn not_found() -> HeaderResponse { + HeaderResponse { + status_code: StatusCode::NotFound.into(), + body: vec![], + } + } + + fn invalid() -> HeaderResponse { + HeaderResponse { + status_code: StatusCode::Invalid.into(), + body: vec![], + } + } } pub(super) trait ExtendedHeaderExt { diff --git a/node/src/store.rs b/node/src/store.rs index 7f6b1353..c8654434 100644 --- a/node/src/store.rs +++ b/node/src/store.rs @@ -391,7 +391,7 @@ pub mod tests { )); } - fn gen_filled_store(amount: u64) -> (InMemoryStore, ExtendedHeaderGenerator) { + pub fn gen_filled_store(amount: u64) -> (InMemoryStore, ExtendedHeaderGenerator) { let s = InMemoryStore::new(); let mut gen = ExtendedHeaderGenerator::new();