Skip to content

Commit

Permalink
Use Notify to coordinate waiters
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Jan 26, 2024
1 parent 96e09c5 commit a3f0b2e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 86 deletions.
2 changes: 1 addition & 1 deletion bb8/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async-trait = "0.1"
futures-channel = "0.3.2"
futures-util = { version = "0.3.2", default-features = false, features = ["channel"] }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.0", features = ["rt", "time"] }
tokio = { version = "1.0", features = ["rt", "sync", "time"] }

[dev-dependencies]
tokio = { version = "1.0", features = ["macros"] }
Expand Down
28 changes: 11 additions & 17 deletions bb8/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::future::Future;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};

use futures_channel::oneshot;
use futures_util::stream::{FuturesUnordered, StreamExt};
use futures_util::TryFutureExt;
use tokio::spawn;
Expand Down Expand Up @@ -111,12 +110,20 @@ where
make_pooled_conn: impl Fn(&'a Self, Conn<M::Connection>) -> PooledConnection<'b, M>,
) -> Result<PooledConnection<'b, M>, RunError<M::Error>> {
loop {
let mut conn = match self.inner.pop() {
Some((conn, approvals)) => {
let (conn, approvals) = self.inner.pop();
let mut conn = match conn {
Some(conn) => {
self.spawn_replenishing_approvals(approvals);
make_pooled_conn(self, conn)
}
None => break,
None => {
let notified = self.inner.notify.notified();
self.spawn_replenishing_approvals(approvals);
match timeout(self.inner.statics.connection_timeout, notified).await {
Ok(()) => continue,
Err(_) => return Err(RunError::TimedOut),
}
}
};

if !self.inner.statics.test_on_check_out {
Expand All @@ -132,19 +139,6 @@ where
}
}
}

let (tx, rx) = oneshot::channel();
{
let mut locked = self.inner.internals.lock();
let approvals = locked.push_waiter(tx, &self.inner.statics);
self.spawn_replenishing_approvals(approvals);
};

match rx.await {
Ok(Ok(mut guard)) => Ok(make_pooled_conn(self, guard.extract())),
Ok(Err(e)) => Err(RunError::User(e)),
Err(_) => Err(RunError::TimedOut),
}
}

pub(crate) async fn connect(&self) -> Result<M::Connection, M::Error> {
Expand Down
83 changes: 16 additions & 67 deletions bb8/src/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Instant;

use crate::{api::QueueStrategy, lock::Mutex};
use futures_channel::oneshot;
use tokio::sync::Notify;

use crate::api::{Builder, ManageConnection};
use std::collections::VecDeque;
Expand All @@ -17,6 +17,7 @@ where
pub(crate) statics: Builder<M>,
pub(crate) manager: M,
pub(crate) internals: Mutex<PoolInternals<M>>,
pub(crate) notify: Arc<Notify>,
}

impl<M> SharedPool<M>
Expand All @@ -28,24 +29,22 @@ where
statics,
manager,
internals: Mutex::new(PoolInternals::default()),
notify: Arc::new(Notify::new()),
}
}

pub(crate) fn pop(&self) -> Option<(Conn<M::Connection>, ApprovalIter)> {
pub(crate) fn pop(&self) -> (Option<Conn<M::Connection>>, ApprovalIter) {
let mut locked = self.internals.lock();
let idle = locked.conns.pop_front()?;
Some((idle.conn, locked.wanted(&self.statics)))
let conn = locked.conns.pop_front().map(|idle| idle.conn);
let approvals = match &conn {
Some(_) => locked.wanted(&self.statics),
None => locked.approvals(&self.statics, 1),
};

(conn, approvals)
}

pub(crate) fn forward_error(&self, mut err: M::Error) {
let mut locked = self.internals.lock();
while let Some(waiter) = locked.waiters.pop_front() {
match waiter.send(Err(err)) {
Ok(_) => return,
Err(Err(e)) => err = e,
Err(Ok(_)) => unreachable!(),
}
}
pub(crate) fn forward_error(&self, err: M::Error) {
self.statics.error_sink.sink(err);
}
}
Expand All @@ -56,7 +55,6 @@ pub(crate) struct PoolInternals<M>
where
M: ManageConnection,
{
waiters: VecDeque<oneshot::Sender<Result<InternalsGuard<M>, M::Error>>>,
conns: VecDeque<IdleConn<M::Connection>>,
num_conns: u32,
pending_conns: u32,
Expand All @@ -77,26 +75,14 @@ where
self.num_conns += 1;
}

let queue_strategy = pool.statics.queue_strategy;

let mut guard = InternalsGuard::new(conn, pool);
while let Some(waiter) = self.waiters.pop_front() {
// This connection is no longer idle, send it back out
match waiter.send(Ok(guard)) {
Ok(()) => return,
Err(Ok(g)) => {
guard = g;
}
Err(Err(_)) => unreachable!(),
}
}

// Queue it in the idle queue
let conn = IdleConn::from(guard.conn.take().unwrap());
match queue_strategy {
let conn = IdleConn::from(conn);
match pool.statics.queue_strategy {
QueueStrategy::Fifo => self.conns.push_back(conn),
QueueStrategy::Lifo => self.conns.push_front(conn),
}

pool.notify.notify_one();
}

pub(crate) fn connect_failed(&mut self, _: Approval) {
Expand All @@ -120,15 +106,6 @@ where
self.approvals(config, wanted)
}

pub(crate) fn push_waiter(
&mut self,
waiter: oneshot::Sender<Result<InternalsGuard<M>, M::Error>>,
config: &Builder<M>,
) -> ApprovalIter {
self.waiters.push_back(waiter);
self.approvals(config, 1)
}

fn approvals(&mut self, config: &Builder<M>, num: u32) -> ApprovalIter {
let current = self.num_conns + self.pending_conns;
let allowed = if current < config.max_size {
Expand Down Expand Up @@ -174,41 +151,13 @@ where
{
fn default() -> Self {
Self {
waiters: VecDeque::new(),
conns: VecDeque::new(),
num_conns: 0,
pending_conns: 0,
}
}
}

pub(crate) struct InternalsGuard<M: ManageConnection> {
conn: Option<Conn<M::Connection>>,
pool: Arc<SharedPool<M>>,
}

impl<M: ManageConnection> InternalsGuard<M> {
fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
Self {
conn: Some(conn),
pool,
}
}

pub(crate) fn extract(&mut self) -> Conn<M::Connection> {
self.conn.take().unwrap() // safe: can only be `None` after `Drop`
}
}

impl<M: ManageConnection> Drop for InternalsGuard<M> {
fn drop(&mut self) {
if let Some(conn) = self.conn.take() {
let mut locked = self.pool.internals.lock();
locked.put(conn, None, self.pool.clone());
}
}
}

#[must_use]
pub(crate) struct ApprovalIter {
num: usize,
Expand Down
2 changes: 1 addition & 1 deletion bb8/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async fn test_lazy_initialization_failure_no_retry() {
.build_unchecked(manager);

let res = pool.get().await;
assert_eq!(res.unwrap_err(), RunError::User(Error));
assert_eq!(res.unwrap_err(), RunError::TimedOut);
}

#[tokio::test]
Expand Down

0 comments on commit a3f0b2e

Please sign in to comment.