Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow polling only when worker is ready #472

Merged
merged 6 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/redis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,3 @@ jobs:
working-directory: packages/apalis-redis
env:
REDIS_URL: redis://127.0.0.1/
- run: cargo test -- --test-threads=1
working-directory: packages/apalis-redis
env:
REDIS_URL: redis://127.0.0.1/
1 change: 1 addition & 0 deletions packages/apalis-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ pub mod test_utils {
{
let worker_id = WorkerId::new("test-worker");
let worker = Worker::new(worker_id, crate::worker::Context::default());
worker.start();
let b = backend.clone();
let mut poller = b.poll::<S>(&worker);
let (stop_tx, mut stop_rx) = channel::<()>(1);
Expand Down
176 changes: 176 additions & 0 deletions packages/apalis-core/src/worker/call_all.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use futures::{ready, stream::FuturesUnordered, Stream};
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;

pin_project! {
/// A stream of responses received from the inner service in received order.
#[derive(Debug)]
pub(super) struct CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
#[pin]
inner: CallAll<Svc, S, FuturesUnordered<Svc::Future>>,
}
}

impl<Svc, S> CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
/// Create new [`CallAllUnordered`] combinator.
///
/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
pub(super) fn new(service: Svc, stream: S) -> CallAllUnordered<Svc, S> {
CallAllUnordered {
inner: CallAll::new(service, stream, FuturesUnordered::new()),
}
}
}

impl<Svc, S> Stream for CallAllUnordered<Svc, S>
where
Svc: Service<S::Item>,
S: Stream,
{
type Item = Result<Svc::Response, Svc::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}

impl<F: Future> Drive<F> for FuturesUnordered<F> {
fn is_empty(&self) -> bool {
FuturesUnordered::is_empty(self)
}

fn push(&mut self, future: F) {
FuturesUnordered::push(self, future)
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<F::Output>> {
Stream::poll_next(Pin::new(self), cx)
}
}

pin_project! {
/// The [`Future`] returned by the [`ServiceExt::call_all`] combinator.
pub(crate) struct CallAll<Svc, S, Q>
where
S: Stream,
{
service: Option<Svc>,
#[pin]
stream: S,
queue: Q,
eof: bool,
curr_req: Option<S::Item>
}
}

impl<Svc, S, Q> fmt::Debug for CallAll<Svc, S, Q>
where
Svc: fmt::Debug,
S: Stream + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CallAll")
.field("service", &self.service)
.field("stream", &self.stream)
.field("eof", &self.eof)
.finish()
}
}

pub(crate) trait Drive<F: Future> {
fn is_empty(&self) -> bool;

fn push(&mut self, future: F);

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<F::Output>>;
}

impl<Svc, S, Q> CallAll<Svc, S, Q>
where
Svc: Service<S::Item>,
S: Stream,
Q: Drive<Svc::Future>,
{
pub(crate) const fn new(service: Svc, stream: S, queue: Q) -> CallAll<Svc, S, Q> {
CallAll {
service: Some(service),
stream,
queue,
eof: false,
curr_req: None,
}
}
}

impl<Svc, S, Q> Stream for CallAll<Svc, S, Q>
where
Svc: Service<S::Item>,
S: Stream,
Q: Drive<Svc::Future>,
{
type Item = Result<Svc::Response, Svc::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

loop {
// First, see if we have any responses to yield
if let Poll::Ready(r) = this.queue.poll(cx) {
if let Some(rsp) = r.transpose()? {
return Poll::Ready(Some(Ok(rsp)));
}
}

// If there are no more requests coming, check if we're done
if *this.eof {
if this.queue.is_empty() {
return Poll::Ready(None);
} else {
return Poll::Pending;
}
}

// Then, see that the service is ready for another request
let svc = this
.service
.as_mut()
.expect("Using CallAll after extracting inner Service");

if let Err(e) = ready!(svc.poll_ready(cx)) {
// Set eof to prevent the service from being called again after a `poll_ready` error
*this.eof = true;
return Poll::Ready(Some(Err(e)));
}

// If not done, and we don't have a stored request, gather the next request from the
// stream (if there is one), or return `Pending` if the stream is not ready.
if this.curr_req.is_none() {
*this.curr_req = match ready!(this.stream.as_mut().poll_next(cx)) {
Some(next_req) => Some(next_req),
None => {
// Mark that there will be no more requests.
*this.eof = true;
continue;
}
};
}

// Unwrap: The check above always sets `this.curr_req` if none.
this.queue.push(svc.call(this.curr_req.take().unwrap()));
}
}
}
75 changes: 71 additions & 4 deletions packages/apalis-core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::monitor::shutdown::Shutdown;
use crate::request::Request;
use crate::service_fn::FromRequest;
use crate::task::task_id::TaskId;
use call_all::CallAllUnordered;
use futures::future::{join, select, BoxFuture};
use futures::stream::BoxStream;
use futures::{Future, FutureExt, Stream, StreamExt};
Expand All @@ -19,9 +20,10 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::task::{Context as TaskCtx, Poll, Waker};
use thiserror::Error;
use tower::util::CallAllUnordered;
use tower::{Layer, Service, ServiceBuilder};

mod call_all;

/// A worker name wrapper usually used by Worker builder
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkerId {
Expand Down Expand Up @@ -208,6 +210,12 @@ impl Worker<Context> {
}
false
}
/// Start running the worker
pub fn start(&self) {
self.state.running.store(true, Ordering::Relaxed);
self.state.is_ready.store(true, Ordering::Release);
self.emit(Event::Start);
}
}

impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
Expand Down Expand Up @@ -290,13 +298,14 @@ impl<S, P> Worker<Ready<S, P>> {
Ctx: Send + 'static + Sync,
Res: 'static,
{
let worker_id = self.id().clone();
let worker_id = self.id;
let ctx = Context {
running: Arc::default(),
task_count: Arc::default(),
wakers: Arc::default(),
shutdown: self.state.shutdown,
event_handler: self.state.event_handler.clone(),
is_ready: Arc::default(),
};
let worker = Worker {
id: worker_id.clone(),
Expand All @@ -310,6 +319,7 @@ impl<S, P> Worker<Ready<S, P>> {
let layer = poller.layer;
let service = ServiceBuilder::new()
.layer(TrackerLayer::new(worker.state.clone()))
.layer(ReadinessLayer::new(worker.state.is_ready.clone()))
.layer(Data::new(worker.clone()))
.layer(layer)
.service(service);
Expand Down Expand Up @@ -366,9 +376,8 @@ impl Future for Runnable {
let poller_future = async { while (poller.next().await).is_some() {} };

if !this.running {
worker.running.store(true, Ordering::Relaxed);
worker.start();
this.running = true;
worker.emit(Event::Start);
}
let combined = Box::pin(join(poller_future, heartbeat.as_mut()));

Expand All @@ -395,6 +404,7 @@ pub struct Context {
running: Arc<AtomicBool>,
shutdown: Option<Shutdown>,
event_handler: EventHandler,
is_ready: Arc<AtomicBool>,
}

impl fmt::Debug for Context {
Expand Down Expand Up @@ -497,6 +507,11 @@ impl Context {
}
}
}

/// Returns if the worker is ready to consume new tasks
pub fn is_ready(&self) -> bool {
self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
}
}

impl Future for Context {
Expand Down Expand Up @@ -557,6 +572,58 @@ where
}
}

#[derive(Clone)]
struct ReadinessLayer {
is_ready: Arc<AtomicBool>,
}

impl ReadinessLayer {
fn new(is_ready: Arc<AtomicBool>) -> Self {
Self { is_ready }
}
}

impl<S> Layer<S> for ReadinessLayer {
type Service = ReadinessService<S>;

fn layer(&self, inner: S) -> Self::Service {
ReadinessService {
inner,
is_ready: self.is_ready.clone(),
}
}
}

struct ReadinessService<S> {
inner: S,
is_ready: Arc<AtomicBool>,
}

impl<S, Request> Service<Request> for ReadinessService<S>
where
S: Service<Request>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// Delegate poll_ready to the inner service
let result = self.inner.poll_ready(cx);
// Update the readiness state based on the result
match &result {
Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
}

result
}

fn call(&mut self, req: Request) -> Self::Future {
self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use std::{ops::Deref, sync::atomic::AtomicUsize};
Expand Down
Loading
Loading