Skip to content

Commit

Permalink
refactor(db): Convert job run DAO to be usable with query_paged_as (#1670)
Browse files Browse the repository at this point in the history

* refactor(db): Convert job run DAO to be usable with query_paged_as

* Rename DbOption to DbNullable

* [autofix.ci] apply automated fixes

* Use the lowercase `as` keyword, as is conventional in Rust

* Add custom trait constraining which types can be used for DbNullable

* Apply suggestions

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
boxbeam and autofix-ci[bot] authored Mar 21, 2024
1 parent 47fd912 commit 57a9187
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 60 deletions.
66 changes: 31 additions & 35 deletions ee/tabby-db-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,44 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{bracketed, parse::Parse, parse_macro_input, Ident, LitStr, Token, Type};
use syn::{bracketed, parse::Parse, parse_macro_input, Expr, Ident, LitStr, Token, Type};

#[derive(Clone)]
struct Column {
name: LitStr,
non_null: bool,
rename: LitStr,
}

impl Parse for Column {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name = input.parse()?;
let name: LitStr = input.parse()?;
let non_null = input.peek(Token![!]);
if non_null {
input.parse::<Token![!]>()?;
}
Ok(Column { name, non_null })
let mut rename = None;
if input.peek(Token![as]) {
input.parse::<Token![as]>()?;
rename = Some(input.parse()?);
}
Ok(Column {
rename: rename.unwrap_or(name.clone()),
name,
non_null,
})
}
}

struct PaginationQueryInput {
pub typ: Type,
pub table_name: LitStr,
pub columns: Vec<Column>,
pub condition: Option<LitStr>,
pub condition: Option<Expr>,
pub limit: Ident,
pub skip_id: Ident,
pub backwards: Ident,
}

mod kw {
use syn::custom_keyword;

custom_keyword!(FROM);
custom_keyword!(WHERE);
}

impl Parse for PaginationQueryInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let typ = input.parse()?;
Expand All @@ -54,19 +57,19 @@ impl Parse for PaginationQueryInput {
columns.push(inner.parse()?);
}

let mut condition = None;
if input.peek(kw::WHERE) {
input.parse::<kw::WHERE>()?;
condition = Some(input.parse()?);
}

input.parse::<Token![,]>()?;
let limit = input.parse()?;
input.parse::<Token![,]>()?;
let skip_id = input.parse()?;
input.parse::<Token![,]>()?;
let backwards = input.parse()?;

let mut condition = None;
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
condition = Some(input.parse()?);
}

Ok(PaginationQueryInput {
typ,
table_name,
Expand All @@ -90,32 +93,25 @@ pub fn query_paged_as(input: TokenStream) -> TokenStream {
.iter()
.map(|col| {
let name = col.name.value();
if col.non_null {
format!("{name} as \"{name}!\"")
} else {
name
}
let rename = col.rename.value();
let non_null = col.non_null.then_some("!").unwrap_or_default();
format!("{name} AS '{rename}{non_null}'")
})
.collect::<Vec<_>>()
.join(", ");
let column_args: Vec<String> = input.columns.iter().map(|col| col.name.value()).collect();
let where_clause = input
.condition
.clone()
.map(|cond| format!("WHERE {}", cond.value()))
.unwrap_or_default();
let limit = input.limit;
let condition = match input.condition {
Some(cond) => quote! {Some(#cond.into())},
Some(cond) => quote! {#cond},
None => quote! {None},
};

let limit = input.limit;
let skip_id = input.skip_id;
let backwards = input.backwards;
quote! {
sqlx::query_as(&crate::make_pagination_query_with_condition({
let _ = sqlx::query_as!(#typ, "SELECT " + #columns + " FROM " + #table_name + #where_clause);
&#table_name
}, &[ #(#column_args),* ], #limit, #skip_id, #backwards, #condition))
}.into()
sqlx::query_as(&crate::make_pagination_query_with_condition({
let _ = sqlx::query_as!(#typ, "SELECT " + #columns + " FROM " + #table_name);
&#table_name
}, &[ #(#column_args),* ], #limit, #skip_id, #backwards, #condition))
}
.into()
}
36 changes: 19 additions & 17 deletions ee/tabby-db/src/job_runs.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::{query, FromRow};
use tabby_db_macros::query_paged_as;

use super::DbConn;
use crate::make_pagination_query_with_condition;
use crate::{DateTimeUtc, DbOption};

#[derive(Default, Clone, FromRow)]
pub struct JobRunDAO {
pub id: i32,
pub id: i64,
#[sqlx(rename = "job")]
pub name: String,
pub exit_code: Option<i32>,
pub exit_code: Option<i64>,
pub stdout: String,
pub stderr: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub created_at: DateTimeUtc,
pub updated_at: DateTimeUtc,

#[sqlx(rename = "end_ts")]
pub finished_at: Option<DateTime<Utc>>,
pub finished_at: DbOption<DateTimeUtc>,
}

/// db read/write operations for `job_runs` table
Expand Down Expand Up @@ -72,26 +72,28 @@ impl DbConn {
} else {
None
};
let query = make_pagination_query_with_condition(
let job_runs: Vec<JobRunDAO> = query_paged_as!(
JobRunDAO,
"job_runs",
&[
[
"id",
"job",
"job" as "name",
"exit_code",
"stdout",
"stderr",
"created_at",
"updated_at",
"end_ts",
"created_at"!,
"updated_at"!,
"end_ts" as "finished_at"
],
limit,
skip_id,
backwards,
condition,
);
condition
)
.fetch_all(&self.pool)
.await?;

let runs = sqlx::query_as(&query).fetch_all(&self.pool).await?;
Ok(runs)
Ok(job_runs)
}

pub async fn cleanup_stale_job_runs(&self) -> Result<()> {
Expand Down
81 changes: 77 additions & 4 deletions ee/tabby-db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ pub use oauth_credential::OAuthCredentialDAO;
pub use repositories::RepositoryDAO;
pub use server_setting::ServerSettingDAO;
use sqlx::{
query, query_scalar, sqlite::SqliteQueryResult, Pool, Sqlite, SqlitePool, Type, Value, ValueRef,
database::HasValueRef, query, query_scalar, sqlite::SqliteQueryResult, Decode, Encode, Pool,
Sqlite, SqlitePool, Type, Value, ValueRef,
};
pub use users::UserDAO;

Expand Down Expand Up @@ -187,6 +188,72 @@ impl DbConn {
}
}

pub trait DbNullable:
for<'a> Decode<'a, Sqlite> + for<'a> Encode<'a, Sqlite> + Type<Sqlite>
{
}
impl DbNullable for DateTimeUtc {}

#[derive(Default)]
pub struct DbOption<T>(Option<T>)
where
T: DbNullable;

impl<T> Type<Sqlite> for DbOption<T>
where
T: Type<Sqlite> + DbNullable,
{
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
T::type_info()
}
}

impl<'a, T> Decode<'a, Sqlite> for DbOption<T>
where
T: DbNullable,
{
fn decode(
value: <Sqlite as HasValueRef<'a>>::ValueRef,
) -> std::prelude::v1::Result<Self, sqlx::error::BoxDynError> {
if value.is_null() {
Ok(Self(None))
} else {
Ok(Self(Some(T::decode(value)?)))
}
}
}

impl<T, F> From<Option<F>> for DbOption<T>
where
T: From<F> + DbNullable,
{
fn from(value: Option<F>) -> Self {
DbOption(value.map(|v| T::from(v)))
}
}

impl<T> DbOption<T>
where
T: DbNullable,
{
pub fn into_option<V>(self) -> Option<V>
where
T: Into<V>,
{
self.0.map(Into::into)
}
}

impl<T> Clone for DbOption<T>
where
T: Clone + DbNullable,
{
fn clone(&self) -> Self {
self.0.clone().into()
}
}

#[derive(Default, Clone)]
pub struct DateTimeUtc(DateTime<Utc>);

impl From<DateTime<Utc>> for DateTimeUtc {
Expand All @@ -195,9 +262,15 @@ impl From<DateTime<Utc>> for DateTimeUtc {
}
}

impl<'a> sqlx::Decode<'a, Sqlite> for DateTimeUtc {
impl From<DateTimeUtc> for DateTime<Utc> {
fn from(val: DateTimeUtc) -> Self {
*val
}
}

impl<'a> Decode<'a, Sqlite> for DateTimeUtc {
fn decode(
value: <Sqlite as sqlx::database::HasValueRef<'a>>::ValueRef,
value: <Sqlite as HasValueRef<'a>>::ValueRef,
) -> std::prelude::v1::Result<Self, sqlx::error::BoxDynError> {
let time: NaiveDateTime = value.to_owned().decode();
Ok(time.into())
Expand All @@ -210,7 +283,7 @@ impl Type<Sqlite> for DateTimeUtc {
}
}

impl<'a> sqlx::Encode<'a, Sqlite> for DateTimeUtc {
impl<'a> Encode<'a, Sqlite> for DateTimeUtc {
fn encode_by_ref(
&self,
buf: &mut <Sqlite as sqlx::database::HasArguments<'a>>::ArgumentBuffer,
Expand Down
8 changes: 4 additions & 4 deletions ee/tabby-webserver/src/service/dao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ impl From<JobRunDAO> for job::JobRun {
Self {
id: run.id.as_id(),
job: run.name,
created_at: run.created_at,
updated_at: run.updated_at,
finished_at: run.finished_at,
exit_code: run.exit_code,
created_at: *run.created_at,
updated_at: *run.updated_at,
finished_at: run.finished_at.into_option(),
exit_code: run.exit_code.map(|i| i as i32),
stdout: run.stdout,
stderr: run.stderr,
}
Expand Down

0 comments on commit 57a9187

Please sign in to comment.