From 1db4a837407eaef050d6b82ba02072bed045dfd3 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Sun, 15 Dec 2024 22:46:23 +0100 Subject: [PATCH] first fixes --- src/zenml/constants.py | 1 + src/zenml/models/v2/base/scoped.py | 86 ----------- src/zenml/models/v2/core/artifact.py | 86 ++++++++++- src/zenml/models/v2/core/model.py | 81 ++++++++++- src/zenml/models/v2/core/pipeline.py | 15 +- .../functional/models/test_sorting.py | 133 ++++++++++++++++++ 6 files changed, 305 insertions(+), 97 deletions(-) create mode 100644 tests/integration/functional/models/test_sorting.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 183b1acce16..a7a13edb614 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ) FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S" SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run" +SORT_BY_LATEST_VERSION_KEY = "latest_version" # Metadata constants METADATA_ORCHESTRATOR_URL = "orchestrator_url" diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f5267f4840d..573116a19ea 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -466,16 +466,6 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): description="Tag to apply to the filter query.", default=None ) - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, - "tag", - ] - - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ - *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, - "tag", - ] - def apply_filter( self, query: AnyQuery, @@ -524,79 +514,3 @@ def get_custom_filters( ) return custom_filters - - def apply_sorting( - self, - query: AnyQuery, - table: Type["AnySchema"], - ) -> AnyQuery: - """Apply sorting to the query. - - Args: - query: The query to which to apply the sorting. - table: The query table. - - Returns: - The query with sorting applied. - """ - sort_by, operand = self.sorting_params - - if sort_by == "tag": - from sqlmodel import and_, asc, desc, func - - from zenml.enums import SorterOps, TaggableResourceTypes - from zenml.zen_stores.schemas import ( - ArtifactSchema, - ArtifactVersionSchema, - ModelSchema, - ModelVersionSchema, - PipelineRunSchema, - PipelineSchema, - RunTemplateSchema, - TagResourceSchema, - TagSchema, - ) - - resource_type_mapping = { - ArtifactSchema: TaggableResourceTypes.ARTIFACT, - ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, - ModelSchema: TaggableResourceTypes.MODEL, - ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, - PipelineSchema: TaggableResourceTypes.PIPELINE, - PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, - RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, - } - - query = ( - query.outerjoin( - TagResourceSchema, - and_( - table.id == TagResourceSchema.resource_id, - TagResourceSchema.resource_type - == resource_type_mapping[table], - ), - ) - .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) - .group_by(table.id) - ) - - if operand == SorterOps.ASCENDING: - query = query.order_by( - asc( - func.group_concat(TagSchema.name, ",").label( - "tags_list" - ) - ) - ) - else: - query = query.order_by( - desc( - func.group_concat(TagSchema.name, ",").label( - "tags_list" - ) - ) - ) - - return query - - return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index c62a7cee1a5..8cebcba6e42 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -13,12 +13,21 @@ # permissions and limitations under the License. """Models representing artifacts.""" -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, +) from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH from zenml.models.v2.base.base import ( BaseDatedResponseBody, BaseIdentifiedResponse, @@ -31,6 +40,11 @@ if TYPE_CHECKING: from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -174,3 +188,71 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter): name: Optional[str] = None has_custom_name: Optional[bool] = None + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Artifacts. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, case, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per artifact + latest_version_subquery = ( + select( + ArtifactVersionSchema.artifact_id, + case( + ( + func.max(ArtifactVersionSchema.created).is_(None), + ArtifactSchema.created, + ), + else_=func.max(ArtifactVersionSchema.created), + ).label("latest_version_created"), + ) + .group_by(ArtifactVersionSchema.artifact_id) + .subquery() + ) + + # Join the subquery with the main artifacts query + query = query.outerjoin( + latest_version_subquery, + ArtifactSchema.id == latest_version_subquery.c.artifact_id, + ) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ArtifactSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ArtifactSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0b5272ab7e6..e8b781ecd2c 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,12 +13,16 @@ # permissions and limitations under the License. """Models representing models.""" -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH +from zenml.constants import ( + SORT_BY_LATEST_VERSION_KEY, + STR_FIELD_MAX_LENGTH, + TEXT_FIELD_MAX_LENGTH, +) from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, WorkspaceScopedResponse, @@ -32,6 +36,11 @@ if TYPE_CHECKING: from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -320,3 +329,71 @@ class ModelFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the Model", ) + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Models. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, case, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per model + latest_version_subquery = ( + select( + ModelVersionSchema.model_id, + case( + ( + func.max(ModelVersionSchema.created).is_(None), + ModelSchema.created, + ), + else_=func.max(ModelVersionSchema.created), + ).label("latest_version_created"), + ) + .group_by(ModelVersionSchema.model_id) + .subquery() + ) + + # Join the subquery with the main artifacts query + query = query.outerjoin( + latest_version_subquery, + ModelSchema.id == latest_version_subquery.c.model_id, + ) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ModelSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ModelSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 199e9cce959..707e9a86104 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -346,7 +346,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, col, desc, func, select + from sqlmodel import asc, case, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema @@ -366,7 +366,7 @@ def apply_sorting( else_=func.max(PipelineRunSchema.created), ).label("latest_run"), ) - .group_by(col(PipelineRunSchema.pipeline_id)) + .group_by(PipelineRunSchema.pipeline_id) .subquery() ) @@ -378,13 +378,14 @@ def apply_sorting( if operand == SorterOps.ASCENDING: query = query.order_by( - asc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) + asc(latest_run_subquery.c.latest_run), + asc(PipelineSchema.id), + ) else: query = query.order_by( - desc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) - + desc(latest_run_subquery.c.latest_run), + desc(PipelineSchema.id), + ) return query else: return super().apply_sorting(query=query, table=table) diff --git a/tests/integration/functional/models/test_sorting.py b/tests/integration/functional/models/test_sorting.py new file mode 100644 index 00000000000..67f17ce9b39 --- /dev/null +++ b/tests/integration/functional/models/test_sorting.py @@ -0,0 +1,133 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Annotated + +from zenml import Model, pipeline, step +from zenml.constants import ( + SORT_BY_LATEST_VERSION_KEY, + SORT_PIPELINES_BY_LATEST_RUN_KEY, +) + + +@step +def first_step() -> Annotated[int, "int_artifact"]: + """Step to return an int.""" + return 3 + + +@pipeline(enable_cache=False) +def first_pipeline(): + """Pipeline definition to test the different sorting mechanisms.""" + _ = first_step() + + +@step +def second_step() -> Annotated[str, "str_artifact"]: + """Step to return a string.""" + return "3" + + +@pipeline(enable_cache=False) +def second_pipeline(): + """Pipeline definition to test the different sorting mechanisms.""" + _ = second_step() + + +def test_sorting_entities(clean_client): + """Testing different sorting functionalities.""" + first_pipeline_first_run = first_pipeline.with_options( + model=Model(name="Model2"), + )() + _ = first_pipeline.with_options( + model=Model(name="Model1", version="second"), + )() + _ = first_pipeline.with_options( + model=Model(name="Model1", version="first"), + )() + second_pipeline_first_run = second_pipeline() + + # Sorting runs by the name of the user + clean_client.list_pipeline_runs(sort_by="user") + clean_client.list_pipeline_runs(sort_by="asc:user") + clean_client.list_pipeline_runs(sort_by="desc:user") + + # Sorting runs by the name of the workspace + clean_client.list_pipeline_runs(sort_by="workspace") + clean_client.list_pipeline_runs(sort_by="asc:workspace") + clean_client.list_pipeline_runs(sort_by="desc:workspace") + + # Sorting pipelines by latest run + results = clean_client.list_pipelines( + sort_by=f"{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == first_pipeline_first_run.pipeline.id + assert results[1].id == second_pipeline_first_run.pipeline.id + + results = clean_client.list_pipelines( + sort_by=f"asc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == first_pipeline_first_run.pipeline.id + assert results[1].id == second_pipeline_first_run.pipeline.id + + results = clean_client.list_pipelines( + sort_by=f"desc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == second_pipeline_first_run.pipeline.id + assert results[1].id == first_pipeline_first_run.pipeline.id + + # Sorting runs by pipeline name + results = clean_client.list_pipeline_runs(sort_by="asc:name") + assert results[0].name.startswith("first_") + assert results[-1].name.startswith("second_") + + # Sorting runs by stack name + clean_client.list_pipeline_runs(sort_by="asc:stack") + clean_client.list_pipeline_runs(sort_by="desc:stack") + + # Sorting runs by model name + results = clean_client.list_pipeline_runs(sort_by="asc:model") + assert results[0].model_version.model.name == "Model1" + assert results[-1].model_version.model.name == "Model2" + clean_client.list_pipeline_runs(sort_by="desc:model") + + # Sorting runs by model version + results = clean_client.list_pipeline_runs(sort_by="asc:model_version") + + assert results[0].model_version.name == "1" + assert results[-1].model_version.name == "second" + + clean_client.list_pipeline_runs(sort_by="desc:model") + + # Sorting artifacts by latest version + results = clean_client.list_artifacts( + sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "int_artifact" + + results = clean_client.list_artifacts( + sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "str_artifact" + + # Sorting models by latest version + results = clean_client.list_models( + sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "Model2" + + results = clean_client.list_models( + sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "Model1"