From 37d2f0c761cee88e77c2f20494057927bcbefa86 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 16 Dec 2024 01:12:57 +0100 Subject: [PATCH] final fixes --- src/zenml/models/v2/base/scoped.py | 4 ++++ src/zenml/models/v2/core/artifact.py | 16 +++++++++------- src/zenml/models/v2/core/model.py | 16 +++++++++------- src/zenml/models/v2/core/pipeline.py | 16 +++++++++------- src/zenml/models/v2/core/pipeline_run.py | 2 ++ 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 573116a19e..00c4176988 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -245,6 +245,8 @@ def apply_sorting( UserSchema, getattr(table, "user_id") == UserSchema.id ) + query = query.add_columns(UserSchema.name) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: @@ -449,6 +451,8 @@ def apply_sorting( getattr(table, "workspace_id") == WorkspaceSchema.id, ) + query = query.add_columns(WorkspaceSchema.name) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index 8cebcba6e4..bd87395b17 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -208,7 +208,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import ( @@ -222,7 +222,7 @@ def apply_sorting( # Subquery to find the latest version per artifact latest_version_subquery = ( select( - ArtifactVersionSchema.artifact_id, + ArtifactSchema.id, case( ( func.max(ArtifactVersionSchema.created).is_(None), @@ -231,14 +231,16 @@ def apply_sorting( else_=func.max(ArtifactVersionSchema.created), ).label("latest_version_created"), ) - .group_by(ArtifactVersionSchema.artifact_id) + .outerjoin( + ArtifactVersionSchema, + ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type] + ) + .group_by(col(ArtifactSchema.id)) .subquery() ) - # Join the subquery with the main artifacts query - query = query.outerjoin( - latest_version_subquery, - ArtifactSchema.id == latest_version_subquery.c.artifact_id, + query = query.add_columns( + latest_version_subquery.c.latest_version_created, ) # Apply sorting based on the operand diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index e8b781ecd2..5669a9d223 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -349,7 +349,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import ( @@ -363,7 +363,7 @@ def apply_sorting( # Subquery to find the latest version per model latest_version_subquery = ( select( - ModelVersionSchema.model_id, + ModelSchema.id, case( ( func.max(ModelVersionSchema.created).is_(None), @@ -372,14 +372,16 @@ def apply_sorting( else_=func.max(ModelVersionSchema.created), ).label("latest_version_created"), ) - .group_by(ModelVersionSchema.model_id) + .outerjoin( + ModelVersionSchema, + ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type] + ) + .group_by(col(ModelSchema.id)) .subquery() ) - # Join the subquery with the main artifacts query - query = query.outerjoin( - latest_version_subquery, - ModelSchema.id == latest_version_subquery.c.model_id, + query = query.add_columns( + latest_version_subquery.c.latest_version_created, ) # Apply sorting based on the operand diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 707e9a8610..7cfd4109a9 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -346,7 +346,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema @@ -357,7 +357,7 @@ def apply_sorting( # Subquery to find the latest run per pipeline latest_run_subquery = ( select( - PipelineRunSchema.pipeline_id, + PipelineSchema.id, case( ( func.max(PipelineRunSchema.created).is_(None), @@ -366,14 +366,16 @@ def apply_sorting( else_=func.max(PipelineRunSchema.created), ).label("latest_run"), ) - .group_by(PipelineRunSchema.pipeline_id) + .outerjoin( + PipelineRunSchema, + PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type] + ) + .group_by(col(PipelineSchema.id)) .subquery() ) - # Join the subquery with the pipelines - query = query.outerjoin( - latest_run_subquery, - PipelineSchema.id == latest_run_subquery.c.pipeline_id, + query = query.add_columns( + latest_run_subquery.c.latest_run, ) if operand == SorterOps.ASCENDING: diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 3a22f64295..740fbc6711 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -982,6 +982,8 @@ def apply_sorting( else: return super().apply_sorting(query=query, table=table) + query = query.add_columns(column) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: