Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Dec 16, 2024
1 parent 1db4a83 commit 37d2f0c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 21 deletions.
4 changes: 4 additions & 0 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def apply_sorting(
UserSchema, getattr(table, "user_id") == UserSchema.id
)

query = query.add_columns(UserSchema.name)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down Expand Up @@ -449,6 +451,8 @@ def apply_sorting(
getattr(table, "workspace_id") == WorkspaceSchema.id,
)

query = query.add_columns(WorkspaceSchema.name)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down
16 changes: 9 additions & 7 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def apply_sorting(
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, desc, func, select
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
Expand All @@ -222,7 +222,7 @@ def apply_sorting(
# Subquery to find the latest version per artifact
latest_version_subquery = (
select(
ArtifactVersionSchema.artifact_id,
ArtifactSchema.id,
case(
(
func.max(ArtifactVersionSchema.created).is_(None),
Expand All @@ -231,14 +231,16 @@ def apply_sorting(
else_=func.max(ArtifactVersionSchema.created),
).label("latest_version_created"),
)
.group_by(ArtifactVersionSchema.artifact_id)
.outerjoin(
ArtifactVersionSchema,
ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type]
)
.group_by(col(ArtifactSchema.id))
.subquery()
)

# Join the subquery with the main artifacts query
query = query.outerjoin(
latest_version_subquery,
ArtifactSchema.id == latest_version_subquery.c.artifact_id,
query = query.add_columns(
latest_version_subquery.c.latest_version_created,
)

# Apply sorting based on the operand
Expand Down
16 changes: 9 additions & 7 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def apply_sorting(
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, desc, func, select
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
Expand All @@ -363,7 +363,7 @@ def apply_sorting(
# Subquery to find the latest version per model
latest_version_subquery = (
select(
ModelVersionSchema.model_id,
ModelSchema.id,
case(
(
func.max(ModelVersionSchema.created).is_(None),
Expand All @@ -372,14 +372,16 @@ def apply_sorting(
else_=func.max(ModelVersionSchema.created),
).label("latest_version_created"),
)
.group_by(ModelVersionSchema.model_id)
.outerjoin(
ModelVersionSchema,
ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type]
)
.group_by(col(ModelSchema.id))
.subquery()
)

# Join the subquery with the main artifacts query
query = query.outerjoin(
latest_version_subquery,
ModelSchema.id == latest_version_subquery.c.model_id,
query = query.add_columns(
latest_version_subquery.c.latest_version_created,
)

# Apply sorting based on the operand
Expand Down
16 changes: 9 additions & 7 deletions src/zenml/models/v2/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def apply_sorting(
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, desc, func, select
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema
Expand All @@ -357,7 +357,7 @@ def apply_sorting(
# Subquery to find the latest run per pipeline
latest_run_subquery = (
select(
PipelineRunSchema.pipeline_id,
PipelineSchema.id,
case(
(
func.max(PipelineRunSchema.created).is_(None),
Expand All @@ -366,14 +366,16 @@ def apply_sorting(
else_=func.max(PipelineRunSchema.created),
).label("latest_run"),
)
.group_by(PipelineRunSchema.pipeline_id)
.outerjoin(
PipelineRunSchema,
PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type]
)
.group_by(col(PipelineSchema.id))
.subquery()
)

# Join the subquery with the pipelines
query = query.outerjoin(
latest_run_subquery,
PipelineSchema.id == latest_run_subquery.c.pipeline_id,
query = query.add_columns(
latest_run_subquery.c.latest_run,
)

if operand == SorterOps.ASCENDING:
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ def apply_sorting(
else:
return super().apply_sorting(query=query, table=table)

query = query.add_columns(column)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down

0 comments on commit 37d2f0c

Please sign in to comment.