From b597cf487ce47b348993b2f1d89d38e248742987 Mon Sep 17 00:00:00 2001 From: Geoffrey Mureithi <95377562+geofmureithi@users.noreply.github.com> Date: Tue, 3 Dec 2024 07:20:34 +0300 Subject: [PATCH] fix: allow polling only when worker is ready (#472) * fix: allow checking if service is ready * fix: handle worker readiness before polling next * lint: cargo clippy * fix: get tests working * fix: set start to true instead of false * fix: get integration tests passing --- .github/workflows/redis.yaml | 4 - packages/apalis-core/src/lib.rs | 1 + packages/apalis-core/src/worker/call_all.rs | 176 ++++++++++++++++++++ packages/apalis-core/src/worker/mod.rs | 75 ++++++++- packages/apalis-redis/src/storage.rs | 55 +++--- packages/apalis-sql/src/mysql.rs | 58 ++++--- packages/apalis-sql/src/postgres.rs | 41 ++--- packages/apalis-sql/src/sqlite.rs | 59 ++++--- 8 files changed, 367 insertions(+), 102 deletions(-) create mode 100644 packages/apalis-core/src/worker/call_all.rs diff --git a/.github/workflows/redis.yaml b/.github/workflows/redis.yaml index 66d3673..3c0a006 100644 --- a/.github/workflows/redis.yaml +++ b/.github/workflows/redis.yaml @@ -26,7 +26,3 @@ jobs: working-directory: packages/apalis-redis env: REDIS_URL: redis://127.0.0.1/ - - run: cargo test -- --test-threads=1 - working-directory: packages/apalis-redis - env: - REDIS_URL: redis://127.0.0.1/ diff --git a/packages/apalis-core/src/lib.rs b/packages/apalis-core/src/lib.rs index 6e356d4..f4793e9 100644 --- a/packages/apalis-core/src/lib.rs +++ b/packages/apalis-core/src/lib.rs @@ -227,6 +227,7 @@ pub mod test_utils { { let worker_id = WorkerId::new("test-worker"); let worker = Worker::new(worker_id, crate::worker::Context::default()); + worker.start(); let b = backend.clone(); let mut poller = b.poll::(&worker); let (stop_tx, mut stop_rx) = channel::<()>(1); diff --git a/packages/apalis-core/src/worker/call_all.rs b/packages/apalis-core/src/worker/call_all.rs new file mode 100644 index 0000000..711998a --- /dev/null +++ b/packages/apalis-core/src/worker/call_all.rs @@ -0,0 +1,176 @@ +use futures::{ready, stream::FuturesUnordered, Stream}; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::Service; + +pin_project! { + /// A stream of responses received from the inner service in received order. + #[derive(Debug)] + pub(super) struct CallAllUnordered + where + Svc: Service, + S: Stream, + { + #[pin] + inner: CallAll>, + } +} + +impl CallAllUnordered +where + Svc: Service, + S: Stream, +{ + /// Create new [`CallAllUnordered`] combinator. + /// + /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html + pub(super) fn new(service: Svc, stream: S) -> CallAllUnordered { + CallAllUnordered { + inner: CallAll::new(service, stream, FuturesUnordered::new()), + } + } +} + +impl Stream for CallAllUnordered +where + Svc: Service, + S: Stream, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } +} + +impl Drive for FuturesUnordered { + fn is_empty(&self) -> bool { + FuturesUnordered::is_empty(self) + } + + fn push(&mut self, future: F) { + FuturesUnordered::push(self, future) + } + + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + Stream::poll_next(Pin::new(self), cx) + } +} + +pin_project! { + /// The [`Future`] returned by the [`ServiceExt::call_all`] combinator. 
+ pub(crate) struct CallAll + where + S: Stream, + { + service: Option, + #[pin] + stream: S, + queue: Q, + eof: bool, + curr_req: Option + } +} + +impl fmt::Debug for CallAll +where + Svc: fmt::Debug, + S: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CallAll") + .field("service", &self.service) + .field("stream", &self.stream) + .field("eof", &self.eof) + .finish() + } +} + +pub(crate) trait Drive { + fn is_empty(&self) -> bool; + + fn push(&mut self, future: F); + + fn poll(&mut self, cx: &mut Context<'_>) -> Poll>; +} + +impl CallAll +where + Svc: Service, + S: Stream, + Q: Drive, +{ + pub(crate) const fn new(service: Svc, stream: S, queue: Q) -> CallAll { + CallAll { + service: Some(service), + stream, + queue, + eof: false, + curr_req: None, + } + } +} + +impl Stream for CallAll +where + Svc: Service, + S: Stream, + Q: Drive, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + // First, see if we have any responses to yield + if let Poll::Ready(r) = this.queue.poll(cx) { + if let Some(rsp) = r.transpose()? { + return Poll::Ready(Some(Ok(rsp))); + } + } + + // If there are no more requests coming, check if we're done + if *this.eof { + if this.queue.is_empty() { + return Poll::Ready(None); + } else { + return Poll::Pending; + } + } + + // Then, see that the service is ready for another request + let svc = this + .service + .as_mut() + .expect("Using CallAll after extracting inner Service"); + + if let Err(e) = ready!(svc.poll_ready(cx)) { + // Set eof to prevent the service from being called again after a `poll_ready` error + *this.eof = true; + return Poll::Ready(Some(Err(e))); + } + + // If not done, and we don't have a stored request, gather the next request from the + // stream (if there is one), or return `Pending` if the stream is not ready. + if this.curr_req.is_none() { + *this.curr_req = match ready!(this.stream.as_mut().poll_next(cx)) { + Some(next_req) => Some(next_req), + None => { + // Mark that there will be no more requests. + *this.eof = true; + continue; + } + }; + } + + // Unwrap: The check above always sets `this.curr_req` if none. 
+ this.queue.push(svc.call(this.curr_req.take().unwrap())); + } + } +} diff --git a/packages/apalis-core/src/worker/mod.rs b/packages/apalis-core/src/worker/mod.rs index 57261b4..bb07621 100644 --- a/packages/apalis-core/src/worker/mod.rs +++ b/packages/apalis-core/src/worker/mod.rs @@ -5,6 +5,7 @@ use crate::monitor::shutdown::Shutdown; use crate::request::Request; use crate::service_fn::FromRequest; use crate::task::task_id::TaskId; +use call_all::CallAllUnordered; use futures::future::{join, select, BoxFuture}; use futures::stream::BoxStream; use futures::{Future, FutureExt, Stream, StreamExt}; @@ -19,9 +20,10 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::task::{Context as TaskCtx, Poll, Waker}; use thiserror::Error; -use tower::util::CallAllUnordered; use tower::{Layer, Service, ServiceBuilder}; +mod call_all; + /// A worker name wrapper usually used by Worker builder #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct WorkerId { @@ -208,6 +210,12 @@ impl Worker { } false } + /// Start running the worker + pub fn start(&self) { + self.state.running.store(true, Ordering::Relaxed); + self.state.is_ready.store(true, Ordering::Release); + self.emit(Event::Start); + } } impl FromRequest> for Worker { @@ -290,13 +298,14 @@ impl Worker> { Ctx: Send + 'static + Sync, Res: 'static, { - let worker_id = self.id().clone(); + let worker_id = self.id; let ctx = Context { running: Arc::default(), task_count: Arc::default(), wakers: Arc::default(), shutdown: self.state.shutdown, event_handler: self.state.event_handler.clone(), + is_ready: Arc::default(), }; let worker = Worker { id: worker_id.clone(), @@ -310,6 +319,7 @@ impl Worker> { let layer = poller.layer; let service = ServiceBuilder::new() .layer(TrackerLayer::new(worker.state.clone())) + .layer(ReadinessLayer::new(worker.state.is_ready.clone())) .layer(Data::new(worker.clone())) .layer(layer) .service(service); @@ -366,9 +376,8 @@ impl Future for Runnable { let poller_future = async { while (poller.next().await).is_some() {} }; if !this.running { - worker.running.store(true, Ordering::Relaxed); + worker.start(); this.running = true; - worker.emit(Event::Start); } let combined = Box::pin(join(poller_future, heartbeat.as_mut())); @@ -395,6 +404,7 @@ pub struct Context { running: Arc, shutdown: Option, event_handler: EventHandler, + is_ready: Arc, } impl fmt::Debug for Context { @@ -497,6 +507,11 @@ impl Context { } } } + + /// Returns if the worker is ready to consume new tasks + pub fn is_ready(&self) -> bool { + self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down() + } } impl Future for Context { @@ -557,6 +572,58 @@ where } } +#[derive(Clone)] +struct ReadinessLayer { + is_ready: Arc, +} + +impl ReadinessLayer { + fn new(is_ready: Arc) -> Self { + Self { is_ready } + } +} + +impl Layer for ReadinessLayer { + type Service = ReadinessService; + + fn layer(&self, inner: S) -> Self::Service { + ReadinessService { + inner, + is_ready: self.is_ready.clone(), + } + } +} + +struct ReadinessService { + inner: S, + is_ready: Arc, +} + +impl Service for ReadinessService +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + // Delegate poll_ready to the inner service + let result = self.inner.poll_ready(cx); + // Update the readiness state based on the result + match &result { + Poll::Ready(Ok(_)) => self.is_ready.store(true, 
Ordering::Release), + Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release), + } + + result + } + + fn call(&mut self, req: Request) -> Self::Future { + self.inner.call(req) + } +} + #[cfg(test)] mod tests { use std::{ops::Deref, sync::atomic::AtomicUsize}; diff --git a/packages/apalis-redis/src/storage.rs b/packages/apalis-redis/src/storage.rs index d29a080..e2e130e 100644 --- a/packages/apalis-redis/src/storage.rs +++ b/packages/apalis-redis/src/storage.rs @@ -479,18 +479,22 @@ where } } _ = poll_next_stm.next() => { - let res = self.fetch_next(worker.id()).await; - match res { - Err(e) => { - worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e)))); - } - Ok(res) => { - for job in res { - if let Err(e) = tx.send(Ok(Some(job))).await { - worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e)))); + if worker.is_ready() { + let res = self.fetch_next(worker.id()).await; + match res { + Err(e) => { + worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e)))); + } + Ok(res) => { + for job in res { + if let Err(e) = tx.send(Ok(Some(job))).await { + worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e)))); + } } } } + } else { + continue; } } @@ -966,6 +970,7 @@ where #[cfg(test)] mod tests { + use apalis_core::worker::Context; use apalis_core::{generic_storage_test, sleep}; use email_service::Email; @@ -1019,17 +1024,17 @@ mod tests { .clone() } - async fn register_worker_at(storage: &mut RedisStorage) -> WorkerId { - let worker = WorkerId::new("test-worker"); - + async fn register_worker_at(storage: &mut RedisStorage) -> Worker { + let worker = Worker::new(WorkerId::new("test-worker"), Context::default()); + worker.start(); storage - .keep_alive(&worker) + .keep_alive(&worker.id()) .await .expect("failed to register worker"); worker } - async fn register_worker(storage: &mut RedisStorage) -> WorkerId { + async fn register_worker(storage: &mut RedisStorage) -> Worker { register_worker_at(storage).await } @@ -1053,9 +1058,9 @@ mod tests { let mut storage = setup().await; push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let _job = consume_one(&mut storage, &worker_id).await; + let _job = consume_one(&mut storage, &worker.id()).await; } #[tokio::test] @@ -1063,9 +1068,9 @@ mod tests { let mut storage = setup().await; push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let ctx = &job.parts.context; let res = 42usize; storage @@ -1085,13 +1090,13 @@ mod tests { push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let job_id = &job.parts.task_id; storage - .kill(&worker_id, &job_id) + .kill(&worker.id(), &job_id) .await .expect("failed to kill job"); @@ -1104,9 +1109,9 @@ mod tests { push_email(&mut storage, example_email()).await; - let worker_id = register_worker_at(&mut storage).await; + let worker = register_worker_at(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; 
sleep(Duration::from_millis(1000)).await; let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(1)).unwrap(); let res = storage @@ -1132,9 +1137,9 @@ mod tests { push_email(&mut storage, example_email()).await; - let worker_id = register_worker_at(&mut storage).await; + let worker = register_worker_at(&mut storage).await; sleep(Duration::from_millis(1100)).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(5)).unwrap(); let res = storage .reenqueue_orphaned(1, dead_since) diff --git a/packages/apalis-sql/src/mysql.rs b/packages/apalis-sql/src/mysql.rs index c925d0d..b391ec4 100644 --- a/packages/apalis-sql/src/mysql.rs +++ b/packages/apalis-sql/src/mysql.rs @@ -138,17 +138,22 @@ where { fn stream_jobs( self, - worker_id: &WorkerId, + worker: &Worker, interval: Duration, buffer_size: usize, ) -> impl Stream>, sqlx::Error>> { let pool = self.pool.clone(); - let worker_id = worker_id.to_string(); + let worker = worker.clone(); + let worker_id = worker.id().to_string(); + try_stream! { let buffer_size = u32::try_from(buffer_size) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?; loop { apalis_core::sleep(interval).await; + if !worker.is_ready() { + continue; + } let pool = pool.clone(); let job_type = self.config.namespace.clone(); let mut tx = pool.begin().await?; @@ -413,7 +418,7 @@ where let mut hb_storage = self.clone(); let requeue_storage = self.clone(); let stream = self - .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) + .stream_jobs(worker, config.poll_interval, config.buffer_size) .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); let w = worker.clone(); @@ -699,12 +704,11 @@ mod tests { async fn consume_one( storage: &mut MysqlStorage, - worker_id: &WorkerId, + worker: &Worker, ) -> Request { - let mut stream = - storage - .clone() - .stream_jobs(worker_id, std::time::Duration::from_secs(10), 1); + let mut stream = storage + .clone() + .stream_jobs(worker, std::time::Duration::from_secs(10), 1); stream .next() .await @@ -724,17 +728,18 @@ mod tests { async fn register_worker_at( storage: &mut MysqlStorage, last_seen: DateTime, - ) -> WorkerId { + ) -> Worker { let worker_id = WorkerId::new("test-worker"); - + let wrk = Worker::new(worker_id, Context::default()); + wrk.start(); storage - .keep_alive_at::(&worker_id, last_seen) + .keep_alive_at::(&wrk.id(), last_seen) .await .expect("failed to register worker"); - worker_id + wrk } - async fn register_worker(storage: &mut MysqlStorage) -> WorkerId { + async fn register_worker(storage: &mut MysqlStorage) -> Worker { let now = Utc::now(); register_worker_at(storage, now).await @@ -762,13 +767,13 @@ mod tests { let mut storage = setup().await; push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let ctx = job.parts.context; // TODO: Fix assertions assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); } @@ -778,14 +783,14 @@ mod tests { push_email(&mut storage, example_email()).await; - let worker_id = 
register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let job_id = &job.parts.task_id; storage - .kill(&worker_id, job_id) + .kill(worker.id(), job_id) .await .expect("failed to kill job"); @@ -808,18 +813,19 @@ mod tests { // register a worker not responding since 6 minutes ago let worker_id = WorkerId::new("test-worker"); - + let worker = Worker::new(worker_id, Context::default()); + worker.start(); let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); let six_minutes_ago = Utc::now() - Duration::from_secs(60 * 6); storage - .keep_alive_at::(&worker_id, six_minutes_ago) + .keep_alive_at::(worker.id(), six_minutes_ago) .await .unwrap(); // fetch job - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); @@ -859,13 +865,15 @@ mod tests { let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let worker_id = WorkerId::new("test-worker"); + let worker = Worker::new(worker_id, Context::default()); + worker.start(); storage - .keep_alive_at::(&worker_id, four_minutes_ago) + .keep_alive_at::(&worker.id(), four_minutes_ago) .await .unwrap(); // fetch job - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); @@ -884,7 +892,7 @@ mod tests { .unwrap(); let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); assert_eq!(*ctx.last_error(), None); assert_eq!(job.parts.attempt.current(), 1); diff --git a/packages/apalis-sql/src/postgres.rs b/packages/apalis-sql/src/postgres.rs index 26664f2..77e96aa 100644 --- a/packages/apalis-sql/src/postgres.rs +++ b/packages/apalis-sql/src/postgres.rs @@ -253,9 +253,10 @@ where } } _ = poll_next_stm.next() => { - if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { - worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e)))); - + if worker.is_ready() { + if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { + worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e)))); + } } } _ = pg_notification.next() => { @@ -484,7 +485,7 @@ where .bind(args) .bind(req.parts.task_id.to_string()) .bind(&job_type) - .bind(&req.parts.context.max_attempts()) + .bind(req.parts.context.max_attempts()) .execute(&self.pool) .await?; Ok(req.parts) @@ -507,7 +508,7 @@ where .bind(job) .bind(task_id) .bind(job_type) - .bind(&parts.context.max_attempts()) + .bind(parts.context.max_attempts()) .bind(on) .execute(&self.pool) .await?; @@ -835,17 +836,19 @@ mod tests { async fn register_worker_at( storage: &mut PostgresStorage, last_seen: Timestamp, - ) -> WorkerId { + ) -> Worker { let worker_id = WorkerId::new("test-worker"); storage .keep_alive_at::(&worker_id, last_seen) .await .expect("failed to register worker"); - worker_id + let wrk = Worker::new(worker_id, Context::default()); + wrk.start(); + wrk } - async fn register_worker(storage: &mut PostgresStorage) -> WorkerId { + async fn register_worker(storage: &mut PostgresStorage) -> Worker { register_worker_at(storage, Utc::now().timestamp()).await } @@ -871,16 +874,16 @@ mod tests { let mut storage = setup().await; 
push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let job_id = &job.parts.task_id; // Refresh our job let job = get_job(&mut storage, job_id).await; let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); } @@ -890,13 +893,13 @@ mod tests { push_email(&mut storage, example_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let job_id = &job.parts.task_id; storage - .kill(&worker_id, job_id) + .kill(&worker.id(), job_id) .await .expect("failed to kill job"); @@ -914,9 +917,9 @@ mod tests { let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); - let worker_id = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; + let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; storage .reenqueue_orphaned(1, five_minutes_ago) .await @@ -942,9 +945,9 @@ mod tests { let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60); let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); - let worker_id = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; + let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker.id()).await; let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); @@ -957,7 +960,7 @@ mod tests { let job = get_job(&mut storage, job_id).await; let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); assert_eq!(*ctx.last_error(), None); assert_eq!(job.parts.attempt.current(), 0); diff --git a/packages/apalis-sql/src/sqlite.rs b/packages/apalis-sql/src/sqlite.rs index f562cc4..e794c71 100644 --- a/packages/apalis-sql/src/sqlite.rs +++ b/packages/apalis-sql/src/sqlite.rs @@ -178,16 +178,21 @@ where { fn stream_jobs( &self, - worker_id: &WorkerId, + worker: &Worker, interval: Duration, buffer_size: usize, ) -> impl Stream>, sqlx::Error>> { let pool = self.pool.clone(); - let worker_id = worker_id.clone(); + let worker = worker.clone(); let config = self.config.clone(); let namespace = Namespace(self.config.namespace.clone()); try_stream! 
{ loop { + apalis_core::sleep(interval).await; + if !worker.is_ready() { + continue; + } + let worker_id = worker.id(); let tx = pool.clone(); let mut tx = tx.acquire().await?; let job_type = &config.namespace; @@ -201,7 +206,7 @@ where .fetch_all(&mut *tx) .await?; for id in ids { - let res = fetch_next(&pool, &worker_id, id.0, &config).await?; + let res = fetch_next(&pool, worker_id, id.0, &config).await?; yield match res { None => None::>, Some(job) => { @@ -214,7 +219,6 @@ where } } }; - apalis_core::sleep(interval).await; } } } @@ -244,7 +248,7 @@ where .bind(raw) .bind(parts.task_id.to_string()) .bind(job_type.to_string()) - .bind(&parts.context.max_attempts()) + .bind(parts.context.max_attempts()) .execute(&self.pool) .await?; Ok(parts) @@ -265,7 +269,7 @@ where .bind(job) .bind(id.to_string()) .bind(job_type) - .bind(&req.parts.context.max_attempts()) + .bind(req.parts.context.max_attempts()) .bind(on) .execute(&self.pool) .await?; @@ -472,7 +476,7 @@ impl let config = self.config.clone(); let controller = self.controller.clone(); let stream = self - .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) + .stream_jobs(worker, config.poll_interval, config.buffer_size) .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); let requeue_storage = self.clone(); @@ -661,10 +665,10 @@ mod tests { async fn consume_one( storage: &mut SqliteStorage, - worker_id: &WorkerId, + worker: &Worker, ) -> Request { let mut stream = storage - .stream_jobs(worker_id, std::time::Duration::from_secs(10), 1) + .stream_jobs(worker, std::time::Duration::from_secs(10), 1) .boxed(); stream .next() @@ -674,17 +678,22 @@ mod tests { .expect("no job is pending") } - async fn register_worker_at(storage: &mut SqliteStorage, last_seen: i64) -> WorkerId { + async fn register_worker_at( + storage: &mut SqliteStorage, + last_seen: i64, + ) -> Worker { let worker_id = WorkerId::new("test-worker"); storage .keep_alive_at::(&worker_id, last_seen) .await .expect("failed to register worker"); - worker_id + let wrk = Worker::new(worker_id, Context::default()); + wrk.start(); + wrk } - async fn register_worker(storage: &mut SqliteStorage) -> WorkerId { + async fn register_worker(storage: &mut SqliteStorage) -> Worker { register_worker_at(storage, Utc::now().timestamp()).await } @@ -706,26 +715,26 @@ mod tests { #[tokio::test] async fn test_consume_last_pushed_job() { let mut storage = setup().await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; push_email(&mut storage, example_good_email()).await; let len = storage.len().await.expect("Could not fetch the jobs count"); assert_eq!(len, 1); - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let ctx = job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id.clone())); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); } #[tokio::test] async fn test_acknowledge_job() { let mut storage = setup().await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; push_email(&mut storage, example_good_email()).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let job_id = &job.parts.task_id; let ctx = &job.parts.context; let res = 1usize; @@ -749,13 +758,13 @@ mod tests { push_email(&mut 
storage, example_good_email()).await; - let worker_id = register_worker(&mut storage).await; + let worker = register_worker(&mut storage).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let job_id = &job.parts.task_id; storage - .kill(&worker_id, job_id) + .kill(&worker.id(), job_id) .await .expect("failed to kill job"); @@ -774,9 +783,9 @@ mod tests { let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60); - let worker_id = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; + let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let job_id = &job.parts.task_id; storage .reenqueue_orphaned(1, five_minutes_ago) @@ -800,9 +809,9 @@ mod tests { let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60); let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60); - let worker_id = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; + let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await; - let job = consume_one(&mut storage, &worker_id).await; + let job = consume_one(&mut storage, &worker).await; let job_id = &job.parts.task_id; storage .reenqueue_orphaned(1, six_minutes_ago) @@ -812,7 +821,7 @@ mod tests { let job = get_job(&mut storage, job_id).await; let ctx = &job.parts.context; assert_eq!(*ctx.status(), State::Running); - assert_eq!(*ctx.lock_by(), Some(worker_id)); + assert_eq!(*ctx.lock_by(), Some(worker.id().clone())); assert!(ctx.lock_at().is_some()); assert_eq!(*ctx.last_error(), None); assert_eq!(job.parts.attempt.current(), 1);
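
Note on the new readiness flow: callers that drive a backend by hand (for example the storage tests touched above) must now mark the worker as started before polling, since the backends skip fetching while `worker.is_ready()` is false. A minimal sketch of that flow, mirroring the test changes in this patch (the commented fetch call is only a placeholder for whatever poll method a given backend exposes):

    use apalis_core::worker::{Context, Worker, WorkerId};

    // Build a worker and flag it as running/ready, as the updated
    // test_utils and storage tests in this patch now do.
    let worker = Worker::new(WorkerId::new("test-worker"), Context::default());
    worker.start();

    // Backends changed in this patch only fetch new jobs while the
    // worker reports readiness (and is not shutting down).
    if worker.is_ready() {
        // e.g. storage.fetch_next(worker.id()).await
    }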