Skip to content

Commit

Permalink
fix: allow polling only when worker is ready (#472)
Browse files Browse the repository at this point in the history
* fix: allow checking if service is ready

* fix: handle worker readiness before polling next

* lint: cargo clippy

* fix: get tests working

* fix: set start to true instead of false

* fix: get integration tests passing
  • Loading branch information
geofmureithi authored Dec 3, 2024
1 parent ad78ef3 commit b597cf4
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 102 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/redis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,3 @@ jobs:
working-directory: packages/apalis-redis
env:
REDIS_URL: redis://127.0.0.1/
- run: cargo test -- --test-threads=1
working-directory: packages/apalis-redis
env:
REDIS_URL: redis://127.0.0.1/
1 change: 1 addition & 0 deletions packages/apalis-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ pub mod test_utils {
{
let worker_id = WorkerId::new("test-worker");
let worker = Worker::new(worker_id, crate::worker::Context::default());
worker.start();
let b = backend.clone();
let mut poller = b.poll::<S>(&worker);
let (stop_tx, mut stop_rx) = channel::<()>(1);
Expand Down
176 changes: 176 additions & 0 deletions packages/apalis-core/src/worker/call_all.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use futures::{ready, stream::FuturesUnordered, Stream};
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;

pin_project! {
/// A stream of responses received from the inner service in received order.
#[derive(Debug)]
pub(super) struct CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
#[pin]
inner: CallAll<Svc, S, FuturesUnordered<Svc::Future>>,
}
}

impl<Svc, S> CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
/// Create new [`CallAllUnordered`] combinator.
///
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
pub(super) fn new(service: Svc, stream: S) -> CallAllUnordered<Svc, S> {
CallAllUnordered {
inner: CallAll::new(service, stream, FuturesUnordered::new()),
}
}
}

impl<Svc, S> Stream for CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
type Item = Result<Svc::Response, Svc::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}

impl<F: Future> Drive<F> for FuturesUnordered<F> {
fn is_empty(&self) -> bool {
FuturesUnordered::is_empty(self)
}

fn push(&mut self, future: F) {
FuturesUnordered::push(self, future)
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<F::Output>> {
Stream::poll_next(Pin::new(self), cx)
}
}

pin_project! {
/// The [`Future`] returned by the [`ServiceExt::call_all`] combinator.
pub(crate) struct CallAll<Svc, S, Q>
where
S: Stream,
{
service: Option<Svc>,
#[pin]
stream: S,
queue: Q,
eof: bool,
curr_req: Option<S::Item>
}
}

impl<Svc, S, Q> fmt::Debug for CallAll<Svc, S, Q>
where
Svc: fmt::Debug,
S: Stream + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CallAll")
.field("service", &self.service)
.field("stream", &self.stream)
.field("eof", &self.eof)
.finish()
}
}

pub(crate) trait Drive<F: Future> {
fn is_empty(&self) -> bool;

fn push(&mut self, future: F);

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<F::Output>>;
}

impl<Svc, S, Q> CallAll<Svc, S, Q>
where
Svc: Service<S::Item>,
S: Stream,
Q: Drive<Svc::Future>,
{
pub(crate) const fn new(service: Svc, stream: S, queue: Q) -> CallAll<Svc, S, Q> {
CallAll {
service: Some(service),
stream,
queue,
eof: false,
curr_req: None,
}
}
}

impl<Svc, S, Q> Stream for CallAll<Svc, S, Q>
where
Svc: Service<S::Item>,
S: Stream,
Q: Drive<Svc::Future>,
{
type Item = Result<Svc::Response, Svc::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

loop {
// First, see if we have any responses to yield
if let Poll::Ready(r) = this.queue.poll(cx) {
if let Some(rsp) = r.transpose()? {
return Poll::Ready(Some(Ok(rsp)));
}
}

// If there are no more requests coming, check if we're done
if *this.eof {
if this.queue.is_empty() {
return Poll::Ready(None);
} else {
return Poll::Pending;
}
}

// Then, see that the service is ready for another request
let svc = this
.service
.as_mut()
.expect("Using CallAll after extracting inner Service");

if let Err(e) = ready!(svc.poll_ready(cx)) {
// Set eof to prevent the service from being called again after a `poll_ready` error
*this.eof = true;
return Poll::Ready(Some(Err(e)));
}

// If not done, and we don't have a stored request, gather the next request from the
// stream (if there is one), or return `Pending` if the stream is not ready.
if this.curr_req.is_none() {
*this.curr_req = match ready!(this.stream.as_mut().poll_next(cx)) {
Some(next_req) => Some(next_req),
None => {
// Mark that there will be no more requests.
*this.eof = true;
continue;
}
};
}

// Unwrap: The check above always sets `this.curr_req` if none.
this.queue.push(svc.call(this.curr_req.take().unwrap()));
}
}
}
75 changes: 71 additions & 4 deletions packages/apalis-core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::monitor::shutdown::Shutdown;
use crate::request::Request;
use crate::service_fn::FromRequest;
use crate::task::task_id::TaskId;
use call_all::CallAllUnordered;
use futures::future::{join, select, BoxFuture};
use futures::stream::BoxStream;
use futures::{Future, FutureExt, Stream, StreamExt};
Expand All @@ -19,9 +20,10 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::task::{Context as TaskCtx, Poll, Waker};
use thiserror::Error;
use tower::util::CallAllUnordered;
use tower::{Layer, Service, ServiceBuilder};

mod call_all;

/// A worker name wrapper usually used by Worker builder
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkerId {
Expand Down Expand Up @@ -208,6 +210,12 @@ impl Worker<Context> {
}
false
}
/// Start running the worker
pub fn start(&self) {
self.state.running.store(true, Ordering::Relaxed);
self.state.is_ready.store(true, Ordering::Release);
self.emit(Event::Start);
}
}

impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
Expand Down Expand Up @@ -290,13 +298,14 @@ impl<S, P> Worker<Ready<S, P>> {
Ctx: Send + 'static + Sync,
Res: 'static,
{
let worker_id = self.id().clone();
let worker_id = self.id;
let ctx = Context {
running: Arc::default(),
task_count: Arc::default(),
wakers: Arc::default(),
shutdown: self.state.shutdown,
event_handler: self.state.event_handler.clone(),
is_ready: Arc::default(),
};
let worker = Worker {
id: worker_id.clone(),
Expand All @@ -310,6 +319,7 @@ impl<S, P> Worker<Ready<S, P>> {
let layer = poller.layer;
let service = ServiceBuilder::new()
.layer(TrackerLayer::new(worker.state.clone()))
.layer(ReadinessLayer::new(worker.state.is_ready.clone()))
.layer(Data::new(worker.clone()))
.layer(layer)
.service(service);
Expand Down Expand Up @@ -366,9 +376,8 @@ impl Future for Runnable {
let poller_future = async { while (poller.next().await).is_some() {} };

if !this.running {
worker.running.store(true, Ordering::Relaxed);
worker.start();
this.running = true;
worker.emit(Event::Start);
}
let combined = Box::pin(join(poller_future, heartbeat.as_mut()));

Expand All @@ -395,6 +404,7 @@ pub struct Context {
running: Arc<AtomicBool>,
shutdown: Option<Shutdown>,
event_handler: EventHandler,
is_ready: Arc<AtomicBool>,
}

impl fmt::Debug for Context {
Expand Down Expand Up @@ -497,6 +507,11 @@ impl Context {
}
}
}

/// Returns if the worker is ready to consume new tasks
pub fn is_ready(&self) -> bool {
self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
}
}

impl Future for Context {
Expand Down Expand Up @@ -557,6 +572,58 @@ where
}
}

#[derive(Clone)]
struct ReadinessLayer {
is_ready: Arc<AtomicBool>,
}

impl ReadinessLayer {
fn new(is_ready: Arc<AtomicBool>) -> Self {
Self { is_ready }
}
}

impl<S> Layer<S> for ReadinessLayer {
type Service = ReadinessService<S>;

fn layer(&self, inner: S) -> Self::Service {
ReadinessService {
inner,
is_ready: self.is_ready.clone(),
}
}
}

struct ReadinessService<S> {
inner: S,
is_ready: Arc<AtomicBool>,
}

impl<S, Request> Service<Request> for ReadinessService<S>
where
S: Service<Request>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// Delegate poll_ready to the inner service
let result = self.inner.poll_ready(cx);
// Update the readiness state based on the result
match &result {
Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
}

result
}

fn call(&mut self, req: Request) -> Self::Future {
self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use std::{ops::Deref, sync::atomic::AtomicUsize};
Expand Down
Loading

0 comments on commit b597cf4

Please sign in to comment.