diff --git a/Cargo.lock b/Cargo.lock index e611a9b..7db0d7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -728,9 +728,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -928,6 +928,18 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bb8" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b10cf871f3ff2ce56432fddc2615ac7acc3aa22ca321f8fea800846fbb32f188" +dependencies = [ + "async-trait", + "futures-util", + "parking_lot", + "tokio", +] + [[package]] name = "bevy" version = "0.13.2" @@ -6667,7 +6679,9 @@ checksum = "216080ab382b992234dda86873c18d4c48358f5cfcb70fd693d7f6f2131b628b" name = "replicate-client" version = "0.0.0" dependencies = [ + "async-trait", "base64 0.21.7", + "bb8", "bytes", "clap", "color-eyre", diff --git a/crates/replicate/client/Cargo.toml b/crates/replicate/client/Cargo.toml index 6e6595d..871f504 100644 --- a/crates/replicate/client/Cargo.toml +++ b/crates/replicate/client/Cargo.toml @@ -9,7 +9,9 @@ description = "A client api for state replication" publish = false [dependencies] +async-trait = "0.1.80" base64.workspace = true +bb8 = "0.8.5" bytes.workspace = true eyre.workspace = true futures.workspace = true diff --git a/crates/replicate/client/src/manager.rs b/crates/replicate/client/src/manager.rs index 475f5f3..7cedc9d 100644 --- a/crates/replicate/client/src/manager.rs +++ b/crates/replicate/client/src/manager.rs @@ -2,23 +2,20 @@ use std::fmt::Debug; -use eyre::Result; -use eyre::{bail, ensure, Context, OptionExt}; +use async_trait::async_trait; +use eyre::{bail, ensure, eyre, Context, OptionExt}; +use eyre::{ContextCompat, Result}; use futures::sink::SinkExt; use futures::stream::StreamExt; use replicate_common::{ messages::manager::{Clientbound as Cb, Serverbound as Sb}, InstanceId, }; -use tokio::sync::{mpsc, oneshot}; use url::Url; use crate::connect_to_url; use crate::Ascii; -/// The number of queued rpc calls allowed before we start erroring. -const RPC_CAPACITY: usize = 64; - type Framed = replicate_common::Framed; /// Manages instances on the instance server. Under the hood, this is all done @@ -29,10 +26,78 @@ type Framed = replicate_common::Framed; /// user IDs. #[derive(Debug)] pub struct Manager { - _conn: wtransport::Connection, + pool: bb8::Pool, url: Url, - task: tokio::task::JoinHandle>, - request_tx: mpsc::Sender<(Sb, oneshot::Sender)>, +} + +#[derive(Debug)] +struct StreamPoolManager { + conn: wtransport::Connection, +} + +impl StreamPoolManager { + fn new(conn: wtransport::Connection) -> Self { + Self { conn } + } +} + +// bb8 returns connections to the pool even if the drop is due to a panic. +// To avoid that, we drop the inner connection if the thread is panicking. +struct DropConnectionOnPanic<'a> { + pooled_connection: bb8::PooledConnection<'a, StreamPoolManager>, +} + +impl<'a> Drop for DropConnectionOnPanic<'a> { + fn drop(&mut self) { + if std::thread::panicking() { + (*self.pooled_connection).take(); + } + } +} + +#[async_trait] +impl bb8::ManageConnection for StreamPoolManager { + /// The connection type this manager deals with. + type Connection = Option; + /// The error type returned by `Connection`s. + type Error = eyre::Report; + /// Attempts to create a new connection. + async fn connect(&self) -> Result { + let bi = wtransport::stream::BiStream::join( + self.conn + .open_bi() + .await + .wrap_err("could not initiate bi stream")? + .await + .wrap_err("could not finish opening bi stream")?, + ); + + let framed = Framed::new(bi); + Ok(Some(framed)) + } + /// Determines if the connection is still connected to the database. + async fn is_valid(&self, framed: &mut Self::Connection) -> Result<(), Self::Error> { + let framed = framed + .as_mut() + .wrap_err("connection was dropped due to panic")?; + framed + .send(Sb::HandshakeRequest) + .await + .wrap_err("failed to send handshake request")?; + let Some(msg) = framed.next().await else { + bail!("Server disconnected before completing handshake"); + }; + let msg = msg.wrap_err("error while receiving handshake response")?; + ensure!( + msg == Cb::HandshakeResponse, + "invalid message during handshake" + ); + Ok(()) + } + /// Synchronously determine if the connection is no longer usable, if possible. + fn has_broken(&self, framed: &mut Self::Connection) -> bool { + framed.is_none() + } } impl Manager { @@ -49,41 +114,11 @@ impl Manager { let conn = connect_to_url(&url, bearer_token) .await .wrap_err("failed to connect to server")?; - let bi = wtransport::stream::BiStream::join( - conn.open_bi() - .await - .wrap_err("could not initiate bi stream")? - .await - .wrap_err("could not finish opening bi stream")?, - ); - - let mut framed = Framed::new(bi); - // Do handshake before anything else - { - framed - .send(Sb::HandshakeRequest) - .await - .wrap_err("failed to send handshake request")?; - let Some(msg) = framed.next().await else { - bail!("Server disconnected before completing handshake"); - }; - let msg = msg.wrap_err("error while receiving handshake response")?; - ensure!( - msg == Cb::HandshakeResponse, - "invalid message during handshake" - ); - } + let manager = StreamPoolManager::new(conn); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); - let (request_tx, request_rx) = mpsc::channel(RPC_CAPACITY); - let task = tokio::spawn(manager_task(framed, request_rx)); - - Ok(Self { - _conn: conn, - url, - task, - request_tx, - }) + Ok(Self { pool, url }) } pub async fn instance_create(&self) -> Result { @@ -102,37 +137,22 @@ impl Manager { Ok(url) } - /// Panics if the connection is already dead - async fn request(&self, request: Sb) -> Result { - let (response_tx, response_rx) = oneshot::channel(); - self.request_tx - .send((request, response_tx)) - .await - .wrap_err("failed to send to manager task")?; - response_rx - .await - .wrap_err("failed to receive from manager task") - } - - /// Destroys the manager and reaps any errors from its networking task - pub async fn join(self) -> Result<()> { - self.task - .await - .wrap_err("panic in manager task, file a bug report on github uwu")? - .wrap_err("error in task") - } - - /// The url of this Manager. - pub fn url(&self) -> &Url { - &self.url + async fn get_framed(&self) -> Result> { + let pooled_connection = self.pool.get().await.map_err(|e| match e { + bb8::RunError::User(eyre) => { + eyre.wrap_err("get from connection pool failed") + } + bb8::RunError::TimedOut => eyre!("connection pool fetch timed out"), + })?; + Ok(DropConnectionOnPanic { pooled_connection }) } -} -async fn manager_task( - mut framed: Framed, - mut request_rx: mpsc::Receiver<(Sb, oneshot::Sender)>, -) -> Result<()> { - while let Some((request, response_tx)) = request_rx.recv().await { + async fn request(&self, request: Sb) -> Result { + let mut wrapper = self.get_framed().await?; + let framed = wrapper + .pooled_connection + .as_mut() + .expect("only emptied in Drop impl"); framed .send(request) .await @@ -142,8 +162,11 @@ async fn manager_task( .await .ok_or_eyre("expected a response from the server")? .wrap_err("error while receiving response")?; - let _ = response_tx.send(response); + Ok(response) + } + + /// The url of this Manager. + pub fn url(&self) -> &Url { + &self.url } - // We only return ok when the manager struct was dropped - Ok(()) }