From bcb652b144552d57386ce5cb9c8c95acffce1f28 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 26 Jan 2024 12:42:27 +0100 Subject: [PATCH] Use Notify to coordinate waiters --- bb8/Cargo.toml | 2 +- bb8/src/inner.rs | 26 ++++---------- bb8/src/internals.rs | 83 +++++++++----------------------------------- bb8/tests/test.rs | 2 +- 4 files changed, 25 insertions(+), 88 deletions(-) diff --git a/bb8/Cargo.toml b/bb8/Cargo.toml index 34b1964..2ad049b 100644 --- a/bb8/Cargo.toml +++ b/bb8/Cargo.toml @@ -14,7 +14,7 @@ async-trait = "0.1" futures-channel = "0.3.2" futures-util = { version = "0.3.2", default-features = false, features = ["channel"] } parking_lot = { version = "0.12", optional = true } -tokio = { version = "1.0", features = ["rt", "time"] } +tokio = { version = "1.0", features = ["rt", "sync", "time"] } [dev-dependencies] tokio = { version = "1.0", features = ["macros"] } diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index ad93cf4..5842324 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -4,7 +4,6 @@ use std::future::Future; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; -use futures_channel::oneshot; use futures_util::stream::{FuturesUnordered, StreamExt}; use futures_util::TryFutureExt; use tokio::spawn; @@ -111,12 +110,14 @@ where make_pooled_conn: impl Fn(&'a Self, Conn) -> PooledConnection<'b, M>, ) -> Result, RunError> { loop { - let mut conn = match self.inner.pop() { - Some((conn, approvals)) => { - self.spawn_replenishing_approvals(approvals); - make_pooled_conn(self, conn) + let (conn, approvals) = self.inner.pop(); + self.spawn_replenishing_approvals(approvals); + let mut conn = match conn { + Some(conn) => make_pooled_conn(self, conn), + None => { + self.inner.notify.notified().await; + continue; } - None => break, }; if !self.inner.statics.test_on_check_out { @@ -132,19 +133,6 @@ where } } } - - let (tx, rx) = oneshot::channel(); - { - let mut locked = self.inner.internals.lock(); - let approvals = locked.push_waiter(tx, &self.inner.statics); - self.spawn_replenishing_approvals(approvals); - }; - - match rx.await { - Ok(Ok(mut guard)) => Ok(make_pooled_conn(self, guard.extract())), - Ok(Err(e)) => Err(RunError::User(e)), - Err(_) => Err(RunError::TimedOut), - } } pub(crate) async fn connect(&self) -> Result { diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 918ab65..8d6585c 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Instant; use crate::{api::QueueStrategy, lock::Mutex}; -use futures_channel::oneshot; +use tokio::sync::Notify; use crate::api::{Builder, ManageConnection}; use std::collections::VecDeque; @@ -17,6 +17,7 @@ where pub(crate) statics: Builder, pub(crate) manager: M, pub(crate) internals: Mutex>, + pub(crate) notify: Arc, } impl SharedPool @@ -28,24 +29,22 @@ where statics, manager, internals: Mutex::new(PoolInternals::default()), + notify: Arc::new(Notify::new()), } } - pub(crate) fn pop(&self) -> Option<(Conn, ApprovalIter)> { + pub(crate) fn pop(&self) -> (Option>, ApprovalIter) { let mut locked = self.internals.lock(); - let idle = locked.conns.pop_front()?; - Some((idle.conn, locked.wanted(&self.statics))) + let conn = locked.conns.pop_front().map(|idle| idle.conn); + let approvals = match &conn { + Some(_) => locked.wanted(&self.statics), + None => locked.approvals(&self.statics, 1), + }; + + (conn, approvals) } - pub(crate) fn forward_error(&self, mut err: M::Error) { - let mut locked = self.internals.lock(); - while let Some(waiter) = locked.waiters.pop_front() { - match waiter.send(Err(err)) { - Ok(_) => return, - Err(Err(e)) => err = e, - Err(Ok(_)) => unreachable!(), - } - } + pub(crate) fn forward_error(&self, err: M::Error) { self.statics.error_sink.sink(err); } } @@ -56,7 +55,6 @@ pub(crate) struct PoolInternals where M: ManageConnection, { - waiters: VecDeque, M::Error>>>, conns: VecDeque>, num_conns: u32, pending_conns: u32, @@ -77,26 +75,14 @@ where self.num_conns += 1; } - let queue_strategy = pool.statics.queue_strategy; - - let mut guard = InternalsGuard::new(conn, pool); - while let Some(waiter) = self.waiters.pop_front() { - // This connection is no longer idle, send it back out - match waiter.send(Ok(guard)) { - Ok(()) => return, - Err(Ok(g)) => { - guard = g; - } - Err(Err(_)) => unreachable!(), - } - } - // Queue it in the idle queue - let conn = IdleConn::from(guard.conn.take().unwrap()); - match queue_strategy { + let conn = IdleConn::from(conn); + match pool.statics.queue_strategy { QueueStrategy::Fifo => self.conns.push_back(conn), QueueStrategy::Lifo => self.conns.push_front(conn), } + + pool.notify.notify_one(); } pub(crate) fn connect_failed(&mut self, _: Approval) { @@ -120,15 +106,6 @@ where self.approvals(config, wanted) } - pub(crate) fn push_waiter( - &mut self, - waiter: oneshot::Sender, M::Error>>, - config: &Builder, - ) -> ApprovalIter { - self.waiters.push_back(waiter); - self.approvals(config, 1) - } - fn approvals(&mut self, config: &Builder, num: u32) -> ApprovalIter { let current = self.num_conns + self.pending_conns; let allowed = if current < config.max_size { @@ -174,7 +151,6 @@ where { fn default() -> Self { Self { - waiters: VecDeque::new(), conns: VecDeque::new(), num_conns: 0, pending_conns: 0, @@ -182,33 +158,6 @@ where } } -pub(crate) struct InternalsGuard { - conn: Option>, - pool: Arc>, -} - -impl InternalsGuard { - fn new(conn: Conn, pool: Arc>) -> Self { - Self { - conn: Some(conn), - pool, - } - } - - pub(crate) fn extract(&mut self) -> Conn { - self.conn.take().unwrap() // safe: can only be `None` after `Drop` - } -} - -impl Drop for InternalsGuard { - fn drop(&mut self) { - if let Some(conn) = self.conn.take() { - let mut locked = self.pool.internals.lock(); - locked.put(conn, None, self.pool.clone()); - } - } -} - #[must_use] pub(crate) struct ApprovalIter { num: usize, diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 70710a2..69853ed 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -282,7 +282,7 @@ async fn test_lazy_initialization_failure_no_retry() { .build_unchecked(manager); let res = pool.get().await; - assert_eq!(res.unwrap_err(), RunError::User(Error)); + assert_eq!(res.unwrap_err(), RunError::TimedOut); } #[tokio::test]