diff --git a/coins/monero/src/rpc/http.rs b/coins/monero/src/rpc/http.rs index 256efc57f..270886380 100644 --- a/coins/monero/src/rpc/http.rs +++ b/coins/monero/src/rpc/http.rs @@ -117,12 +117,50 @@ impl HttpRpc { .map_err(|e| RpcError::ConnectionError(format!("couldn't make request: {e:?}"))) }; + async fn body_from_response(response: Response<'_>) -> Result, RpcError> { + /* + let length = usize::try_from( + response + .headers() + .get("content-length") + .ok_or(RpcError::InvalidNode("no content-length header"))? + .to_str() + .map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))? + .parse::() + .map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?, + ) + .unwrap(); + // Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually + // has to send 1 GB of data to cause a 1 GB allocation + let mut res = Vec::with_capacity(length.max(1024 * 1024)); + let mut body = response.into_body(); + while res.len() < length { + let Some(data) = body.data().await else { break }; + res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref()); + } + */ + + let mut res = Vec::with_capacity(128); + response + .body() + .await + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))? + .read_to_end(&mut res) + .unwrap(); + Ok(res) + } + for attempt in 0 .. 2 { - let response = match &self.authentication { - Authentication::Unauthenticated(client) => client - .request(request_fn(self.url.clone() + "/" + route)?) - .await - .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?, + return Ok(match &self.authentication { + Authentication::Unauthenticated(client) => { + body_from_response( + client + .request(request_fn(self.url.clone() + "/" + route)?) + .await + .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?, + ) + .await? + } Authentication::Authenticated { username, password, connection } => { let mut connection_lock = connection.lock().await; @@ -168,26 +206,16 @@ impl HttpRpc { ); } - let response_result = connection_lock + let response = connection_lock .1 .request(request) .await .map_err(|e| RpcError::ConnectionError(format!("{e:?}"))); - // If the connection entered an error state, drop the cached challenge as challenges are - // per-connection - // We don't need to create a new connection as simple-request will for us - if response_result.is_err() { - connection_lock.0 = None; - } - - // If we're not already on our second attempt and: - // A) We had a connection error - // B) We need to re-auth due to this token being stale - // Move to the next loop iteration (retrying all of this) - if (attempt == 0) && - (response_result.is_err() || { - let response = response_result.as_ref().unwrap(); + let (error, is_stale) = match &response { + Err(e) => (Some(e.clone()), false), + Ok(response) => ( + None, if response.status() == StatusCode::UNAUTHORIZED { if let Some(header) = response.headers().get("www-authenticate") { header @@ -201,49 +229,33 @@ impl HttpRpc { } } else { false - } - }) - { - // Drop the cached authentication before we do + }, + ), + }; + + // If the connection entered an error state, drop the cached challenge as challenges are + // per-connection + // We don't need to create a new connection as simple-request will for us + if error.is_some() || is_stale { connection_lock.0 = None; - continue; + // If we're not already on our second attempt, move to the next loop iteration + // (retrying all of this once) + if attempt == 0 { + continue; + } + if let Some(e) = error { + Err(e)? + } else { + debug_assert!(is_stale); + Err(RpcError::InvalidNode( + "node claimed fresh connection had stale authentication".to_string(), + ))? + } + } else { + body_from_response(response.unwrap()).await? } - - response_result? } - }; - - /* - let length = usize::try_from( - response - .headers() - .get("content-length") - .ok_or(RpcError::InvalidNode("no content-length header"))? - .to_str() - .map_err(|_| RpcError::InvalidNode("non-ascii content-length value"))? - .parse::() - .map_err(|_| RpcError::InvalidNode("non-u32 content-length value"))?, - ) - .unwrap(); - // Only pre-allocate 1 MB so a malicious node which claims a content-length of 1 GB actually - // has to send 1 GB of data to cause a 1 GB allocation - let mut res = Vec::with_capacity(length.max(1024 * 1024)); - let mut body = response.into_body(); - while res.len() < length { - let Some(data) = body.data().await else { break }; - res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref()); - } - */ - - let mut res = Vec::with_capacity(128); - response - .body() - .await - .map_err(|e| RpcError::ConnectionError(format!("{e:?}")))? - .read_to_end(&mut res) - .unwrap(); - - return Ok(res); + }); } unreachable!() diff --git a/common/request/src/lib.rs b/common/request/src/lib.rs index 1764ece27..4c738e2ef 100644 --- a/common/request/src/lib.rs +++ b/common/request/src/lib.rs @@ -79,7 +79,7 @@ impl Client { }) } - pub async fn request>(&self, request: R) -> Result { + pub async fn request>(&self, request: R) -> Result, Error> { let request: Request = request.into(); let mut request = request.0; if let Some(header_host) = request.headers().get(hyper::header::HOST) { @@ -111,7 +111,7 @@ impl Client { .insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?); } - Ok(Response(match &self.connection { + let response = match &self.connection { Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?, Connection::Connection { connector, host, connection } => { let mut connection_lock = connection.lock().await; @@ -125,8 +125,8 @@ impl Client { let call_res = call_res.map_err(Error::ConnectionError); let (requester, connection) = hyper::client::conn::http1::handshake(call_res?).await.map_err(Error::Hyper)?; - // This will die when we drop the requester, so we don't need to track an AbortHandle for - // it + // This will die when we drop the requester, so we don't need to track an AbortHandle + // for it tokio::spawn(connection); *connection_lock = Some(requester); } @@ -137,7 +137,7 @@ impl Client { // Send the request let res = connection.send_request(request).await; if let Ok(res) = res { - return Ok(Response(res)); + return Ok(Response(res, self)); } err = res.err(); } @@ -145,6 +145,8 @@ impl Client { *connection_lock = None; Err(Error::Hyper(err.unwrap()))? } - })) + }; + + Ok(Response(response, self)) } } diff --git a/common/request/src/response.rs b/common/request/src/response.rs index 4611324a4..04c8472b8 100644 --- a/common/request/src/response.rs +++ b/common/request/src/response.rs @@ -4,11 +4,12 @@ use hyper::{ body::{Buf, Body}, }; -use crate::Error; +use crate::{Client, Error}; +// Borrows the client so its async task lives as long as this response exists. #[derive(Debug)] -pub struct Response(pub(crate) hyper::Response); -impl Response { +pub struct Response<'a>(pub(crate) hyper::Response, pub(crate) &'a Client); +impl<'a> Response<'a> { pub fn status(&self) -> StatusCode { self.0.status() }