diff --git a/rexecutor-sqlx/src/backend.rs b/rexecutor-sqlx/src/backend.rs index 0bbe224..a304020 100644 --- a/rexecutor-sqlx/src/backend.rs +++ b/rexecutor-sqlx/src/backend.rs @@ -13,7 +13,7 @@ use rexecutor::{ use tokio::sync::mpsc; use tracing::instrument; -use crate::{stream::ReadyJobStream, RexecutorPgBackend}; +use crate::{stream::ReadyJobStream, RexecutorPgBackend, map_err}; impl RexecutorPgBackend { fn handle_update(result: sqlx::Result, job_id: JobId) -> Result<(), BackendError> { @@ -21,8 +21,7 @@ impl RexecutorPgBackend { Ok(0) => Err(BackendError::JobNotFound(job_id)), Ok(1) => Ok(()), Ok(_) => Err(BackendError::BadState), - // TODO fix this - Err(_error) => Err(BackendError::BadState), + Err(error) => Err(map_err(error)), } } } @@ -59,8 +58,7 @@ impl Backend for RexecutorPgBackend { } else { self.insert_job(job).await } - // TODO error handling - .map_err(|_| BackendError::BadState) + .map_err(map_err) } async fn mark_job_complete(&self, id: JobId) -> Result<(), BackendError> { let result = self._mark_job_complete(id).await; @@ -102,7 +100,7 @@ impl Backend for RexecutorPgBackend { async fn prune_jobs(&self, spec: &PruneSpec) -> Result<(), BackendError> { self.delete_from_spec(spec) .await - .map_err(|_| BackendError::BadState) + .map_err(map_err) } async fn rerun_job(&self, id: JobId) -> Result<(), BackendError> { let result = self.rerun(id).await; @@ -116,7 +114,7 @@ impl Backend for RexecutorPgBackend { async fn query<'a>(&self, query: Query<'a>) -> Result, BackendError> { self.run_query(query) .await - .map_err(|_| BackendError::BadState)? + .map_err(map_err)? .into_iter() .map(TryFrom::try_from) .collect() @@ -216,8 +214,6 @@ mod test { backend: RexecutorPgBackend::from_pool(pool).await.unwrap() ); - // TODO: add tests for ignoring running, cancelled, complete, and discarded jobs - #[sqlx::test] async fn load_job_mark_as_executing_for_executor_returns_none_when_db_empty(pool: PgPool) { let backend: RexecutorPgBackend = pool.into(); diff --git a/rexecutor-sqlx/src/lib.rs b/rexecutor-sqlx/src/lib.rs index 75a501e..3cf8642 100644 --- a/rexecutor-sqlx/src/lib.rs +++ b/rexecutor-sqlx/src/lib.rs @@ -42,13 +42,25 @@ struct Notification { use types::*; +fn map_err(error: sqlx::Error) -> BackendError { + match error { + sqlx::Error::Io(err) => BackendError::Io(err), + sqlx::Error::Tls(err) => BackendError::Io(std::io::Error::other(err)), + sqlx::Error::Protocol(err) => BackendError::Io(std::io::Error::other(err)), + sqlx::Error::AnyDriverError(err) => BackendError::Io(std::io::Error::other(err)), + sqlx::Error::PoolTimedOut => BackendError::Io(std::io::Error::other(error)), + sqlx::Error::PoolClosed => BackendError::Io(std::io::Error::other(error)), + _ => BackendError::BadState, + } +} + impl RexecutorPgBackend { /// Creates a new [`RexecutorPgBackend`] from a db connection string. pub async fn from_db_url(db_url: &str) -> Result { let pool = PgPoolOptions::new() .connect(db_url) .await - .map_err(|_| BackendError::BadState)?; + .map_err(map_err)?; Self::from_pool(pool).await } /// Create a new [`RexecutorPgBackend`] from an existing [`PgPool`]. @@ -59,11 +71,11 @@ impl RexecutorPgBackend { }; let mut listener = PgListener::connect_with(&this.pool) .await - .map_err(|_| BackendError::BadState)?; + .map_err(map_err)?; listener .listen("public.rexecutor_scheduled") .await - .map_err(|_| BackendError::BadState)?; + .map_err(map_err)?; tokio::spawn({ let subscribers = this.subscribers.clone(); @@ -97,7 +109,7 @@ impl RexecutorPgBackend { sqlx::migrate!() .run(&self.pool) .await - .map_err(|_| BackendError::BadState) + .map_err(|err| BackendError::Io(std::io::Error::other(err))) } async fn load_job_mark_as_executing_for_executor( diff --git a/rexecutor/src/backend.rs b/rexecutor/src/backend.rs index 364aa58..b9a6863 100644 --- a/rexecutor/src/backend.rs +++ b/rexecutor/src/backend.rs @@ -406,7 +406,10 @@ pub enum BackendError { /// No jobs was found matching the criteria provided. #[error("Job not found: {0}")] JobNotFound(JobId), - // TODO do we need some sort of IO error here + /// There was an error doing IO with the backend + #[error(transparent)] + Io(std::io::Error), + } #[cfg(test)]