Skip to content

Commit

Permalink
first fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Dec 15, 2024
1 parent d6fae4e commit 1db4a83
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 97 deletions.
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
)
FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S"
SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run"
SORT_BY_LATEST_VERSION_KEY = "latest_version"

# Metadata constants
METADATA_ORCHESTRATOR_URL = "orchestrator_url"
Expand Down
86 changes: 0 additions & 86 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,6 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
description="Tag to apply to the filter query.", default=None
)

# Extend the parent's excluded filter fields with the "tag" field.
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
"tag",
]

# Extend the parent's custom sorting options with "tag".
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
"tag",
]

def apply_filter(
self,
query: AnyQuery,
Expand Down Expand Up @@ -524,79 +514,3 @@ def get_custom_filters(
)

return custom_filters

def apply_sorting(
    self,
    query: AnyQuery,
    table: Type["AnySchema"],
) -> AnyQuery:
    """Apply sorting to the query, with support for sorting by tag.

    When the sort key is ``"tag"``, the query is outer-joined to the tag
    link table and the tag table, grouped by the resource ID and ordered
    by the concatenated tag names. Any other sort key is delegated to the
    parent class.

    Args:
        query: The query to which to apply the sorting.
        table: The query table.

    Returns:
        The query with sorting applied.
    """
    sort_by, operand = self.sorting_params

    # Anything but the custom "tag" key is handled by the parent class.
    if sort_by != "tag":
        return super().apply_sorting(query=query, table=table)

    from sqlmodel import and_, asc, desc, func

    from zenml.enums import SorterOps, TaggableResourceTypes
    from zenml.zen_stores.schemas import (
        ArtifactSchema,
        ArtifactVersionSchema,
        ModelSchema,
        ModelVersionSchema,
        PipelineRunSchema,
        PipelineSchema,
        RunTemplateSchema,
        TagResourceSchema,
        TagSchema,
    )

    # Resource type stored on the tag link rows for each taggable schema.
    taggable_types = {
        ArtifactSchema: TaggableResourceTypes.ARTIFACT,
        ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION,
        ModelSchema: TaggableResourceTypes.MODEL,
        ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION,
        PipelineSchema: TaggableResourceTypes.PIPELINE,
        PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN,
        RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE,
    }

    # Only match tag links that point at this row AND carry the resource
    # type that corresponds to the queried table.
    tag_link_condition = and_(
        table.id == TagResourceSchema.resource_id,
        TagResourceSchema.resource_type == taggable_types[table],
    )

    query = query.outerjoin(TagResourceSchema, tag_link_condition)
    query = query.outerjoin(
        TagSchema, TagResourceSchema.tag_id == TagSchema.id
    )
    query = query.group_by(table.id)

    # One row per resource, ordered by its comma-joined tag names.
    direction = asc if operand == SorterOps.ASCENDING else desc
    tags_list = func.group_concat(TagSchema.name, ",").label("tags_list")

    return query.order_by(direction(tags_list))
86 changes: 84 additions & 2 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
# permissions and limitations under the License.
"""Models representing artifacts."""

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Type,
TypeVar,
)
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH
from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH
from zenml.models.v2.base.base import (
BaseDatedResponseBody,
BaseIdentifiedResponse,
Expand All @@ -31,6 +40,11 @@

if TYPE_CHECKING:
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -174,3 +188,71 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter):

name: Optional[str] = None
has_custom_name: Optional[bool] = None

# Inherit the parent's custom sorting options and add the latest-version
# key, which `apply_sorting` below handles explicitly.
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
    self,
    query: AnyQuery,
    table: Type["AnySchema"],
) -> AnyQuery:
    """Apply sorting to the query for Artifacts.

    Adds support for sorting artifacts by the creation time of their most
    recent version. Artifacts that have no versions are sorted by their
    own creation time instead. Every other sort key is delegated to the
    parent class.

    Args:
        query: The query to which to apply the sorting.
        table: The query table.

    Returns:
        The query with sorting applied.
    """
    from sqlmodel import asc, case, desc, func, select

    from zenml.enums import SorterOps
    from zenml.zen_stores.schemas import (
        ArtifactSchema,
        ArtifactVersionSchema,
    )

    sort_by, operand = self.sorting_params

    # For other sorting cases, delegate to the parent class
    if sort_by != SORT_BY_LATEST_VERSION_KEY:
        return super().apply_sorting(query=query, table=table)

    # Subquery computing one "latest version" timestamp per artifact.
    #
    # It is driven from the artifact table and LEFT OUTER JOINs the
    # versions, so that (a) the reference to `ArtifactSchema.created` is
    # covered by an explicit join instead of an accidental cartesian
    # product, and (b) artifacts without any version still produce a row
    # and fall back to their own creation time.
    newest_version_created = func.max(ArtifactVersionSchema.created)
    latest_version_subquery = (
        select(
            ArtifactSchema.id,
            case(
                (newest_version_created.is_(None), ArtifactSchema.created),
                else_=newest_version_created,
            ).label("latest_version_created"),
        )
        .select_from(ArtifactSchema)
        .outerjoin(
            ArtifactVersionSchema,
            ArtifactSchema.id == ArtifactVersionSchema.artifact_id,
        )
        .group_by(ArtifactSchema.id)
        .subquery()
    )

    # Join the subquery with the main artifacts query
    query = query.outerjoin(
        latest_version_subquery,
        ArtifactSchema.id == latest_version_subquery.c.id,
    )

    # Sort by the latest version timestamp; the artifact ID is used as a
    # tie-breaker so that the ordering is deterministic.
    direction = asc if operand == SorterOps.ASCENDING else desc
    return query.order_by(
        direction(latest_version_subquery.c.latest_version_created),
        direction(ArtifactSchema.id),
    )
81 changes: 79 additions & 2 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# permissions and limitations under the License.
"""Models representing models."""

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
from zenml.constants import (
SORT_BY_LATEST_VERSION_KEY,
STR_FIELD_MAX_LENGTH,
TEXT_FIELD_MAX_LENGTH,
)
from zenml.models.v2.base.scoped import (
WorkspaceScopedRequest,
WorkspaceScopedResponse,
Expand All @@ -32,6 +36,11 @@
if TYPE_CHECKING:
from zenml.model.model import Model
from zenml.models.v2.core.tag import TagResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -320,3 +329,71 @@ class ModelFilter(WorkspaceScopedTaggableFilter):
default=None,
description="Name of the Model",
)

# Inherit the parent's custom sorting options and add the latest-version
# key, which `apply_sorting` below handles explicitly.
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
    self,
    query: AnyQuery,
    table: Type["AnySchema"],
) -> AnyQuery:
    """Apply sorting to the query for Models.

    Adds support for sorting models by the creation time of their most
    recent version. Models that have no versions are sorted by their own
    creation time instead. Every other sort key is delegated to the
    parent class.

    Args:
        query: The query to which to apply the sorting.
        table: The query table.

    Returns:
        The query with sorting applied.
    """
    from sqlmodel import asc, case, desc, func, select

    from zenml.enums import SorterOps
    from zenml.zen_stores.schemas import (
        ModelSchema,
        ModelVersionSchema,
    )

    sort_by, operand = self.sorting_params

    # For other sorting cases, delegate to the parent class
    if sort_by != SORT_BY_LATEST_VERSION_KEY:
        return super().apply_sorting(query=query, table=table)

    # Subquery computing one "latest version" timestamp per model.
    #
    # It is driven from the model table and LEFT OUTER JOINs the
    # versions, so that (a) the reference to `ModelSchema.created` is
    # covered by an explicit join instead of an accidental cartesian
    # product, and (b) models without any version still produce a row
    # and fall back to their own creation time.
    newest_version_created = func.max(ModelVersionSchema.created)
    latest_version_subquery = (
        select(
            ModelSchema.id,
            case(
                (newest_version_created.is_(None), ModelSchema.created),
                else_=newest_version_created,
            ).label("latest_version_created"),
        )
        .select_from(ModelSchema)
        .outerjoin(
            ModelVersionSchema,
            ModelSchema.id == ModelVersionSchema.model_id,
        )
        .group_by(ModelSchema.id)
        .subquery()
    )

    # Join the subquery with the main models query
    query = query.outerjoin(
        latest_version_subquery,
        ModelSchema.id == latest_version_subquery.c.id,
    )

    # Sort by the latest version timestamp; the model ID is used as a
    # tie-breaker so that the ordering is deterministic.
    direction = asc if operand == SorterOps.ASCENDING else desc
    return query.order_by(
        direction(latest_version_subquery.c.latest_version_created),
        direction(ModelSchema.id),
    )
15 changes: 8 additions & 7 deletions src/zenml/models/v2/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def apply_sorting(
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, col, desc, func, select
from sqlmodel import asc, case, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema
Expand All @@ -366,7 +366,7 @@ def apply_sorting(
else_=func.max(PipelineRunSchema.created),
).label("latest_run"),
)
.group_by(col(PipelineRunSchema.pipeline_id))
.group_by(PipelineRunSchema.pipeline_id)
.subquery()
)

Expand All @@ -378,13 +378,14 @@ def apply_sorting(

if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))
asc(latest_run_subquery.c.latest_run),
asc(PipelineSchema.id),
)
else:
query = query.order_by(
desc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))

desc(latest_run_subquery.c.latest_run),
desc(PipelineSchema.id),
)
return query
else:
return super().apply_sorting(query=query, table=table)
Loading

0 comments on commit 1db4a83

Please sign in to comment.