From 7c1f634d5bc48635848f080ff4109e5c33c61b2a Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 16:58:08 +0100 Subject: [PATCH 01/16] first checkpoint --- src/zenml/models/v2/base/filter.py | 9 +- src/zenml/models/v2/base/scoped.py | 165 +++++++++++++++++- src/zenml/models/v2/core/artifact_version.py | 26 +-- src/zenml/models/v2/core/code_repository.py | 13 +- src/zenml/models/v2/core/component.py | 44 ----- src/zenml/models/v2/core/flavor.py | 12 +- src/zenml/models/v2/core/model.py | 58 +----- src/zenml/models/v2/core/model_version.py | 27 --- src/zenml/models/v2/core/pipeline.py | 51 +----- src/zenml/models/v2/core/pipeline_build.py | 10 -- .../models/v2/core/pipeline_deployment.py | 10 -- src/zenml/models/v2/core/pipeline_run.py | 104 ++++++++--- src/zenml/models/v2/core/run_template.py | 26 --- src/zenml/models/v2/core/schedule.py | 10 -- src/zenml/models/v2/core/secret.py | 14 -- src/zenml/models/v2/core/service.py | 15 +- src/zenml/models/v2/core/service_connector.py | 11 -- src/zenml/models/v2/core/stack.py | 27 --- src/zenml/models/v2/core/step_run.py | 11 -- 19 files changed, 249 insertions(+), 394 deletions(-) diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 1c4d2cccfb5..bb15aa16f59 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -436,7 +436,6 @@ class BaseFilter(BaseModel): le=PAGE_SIZE_MAXIMUM, description="Page size", ) - id: Optional[Union[UUID, str]] = Field( default=None, description="Id for this resource", @@ -491,13 +490,13 @@ def validate_sort_by(cls, value: Any) -> Any: ) value = column - if column in cls.FILTER_EXCLUDE_FIELDS: + if column in cls.CUSTOM_SORTING_OPTIONS: + return value + elif column in cls.FILTER_EXCLUDE_FIELDS: raise ValueError( f"This resource can not be sorted by this field: '{value}'" ) - elif column in cls.model_fields: - return value - elif column in cls.CUSTOM_SORTING_OPTIONS: + if column in cls.model_fields: return value else: raise ValueError( diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f563b6dc81c..dacee2b7cd1 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -23,6 +23,7 @@ Optional, Type, TypeVar, + Union, ) from uuid import UUID @@ -151,16 +152,32 @@ class UserScopedFilter(BaseFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, + "user", "scope_user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.CLI_EXCLUDE_FIELDS, + "user_id", "scope_user", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *BaseFilter.CUSTOM_SORTING_OPTIONS, + "user", + ] + scope_user: Optional[UUID] = Field( default=None, description="The user to scope this query to.", ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the user that created the entity.", + union_mode="left_to_right", + ) + user: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the user that created the entity.", + ) def set_scope_user(self, user_id: UUID) -> None: """Set the user that is performing the filtering to scope the response. @@ -170,6 +187,69 @@ def set_scope_user(self, user_id: UUID) -> None: """ self.scope_user = user_id + def get_custom_filters( + self, + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters() + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import ( + PipelineSchema, + UserSchema, + ) + + if self.user: + user_filter = and_( + PipelineSchema.user_id == UserSchema.id, + self.generate_name_or_id_query_conditions( + value=self.user, + table=UserSchema, + additional_columns=["full_name"], + ), + ) + custom_filters.append(user_filter) + + return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import UserSchema + + sort_by, operand = self.sorting_params + + if sort_by == "user": + column = UserSchema.name + + query = query.join(UserSchema, table.user_id == UserSchema.id) + + if operand == SorterOps.ASCENDING: + return query.order_by(asc(column)) + else: + return query.order_by(desc(column)) + + return super().apply_sorting(query=query, table=table) + def apply_filter( self, query: AnyQuery, @@ -240,21 +320,36 @@ def workspace(self) -> "WorkspaceResponse": return self.get_metadata().workspace -class WorkspaceScopedFilter(BaseFilter): +class WorkspaceScopedFilter(UserScopedFilter): """Model to enable advanced scoping with workspace.""" FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.FILTER_EXCLUDE_FIELDS, + *UserScopedFilter.FILTER_EXCLUDE_FIELDS, + "workspace", "scope_workspace", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.CLI_EXCLUDE_FIELDS, + *UserScopedFilter.CLI_EXCLUDE_FIELDS, + "workspace_id", "scope_workspace", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *UserScopedFilter.CUSTOM_SORTING_OPTIONS, + "workspace", + ] scope_workspace: Optional[UUID] = Field( default=None, description="The workspace to scope this query to.", ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the workspace that this entity belongs to.", + union_mode="left_to_right", + ) + workspace: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the workspace that this entity belongs to.", + ) def set_scope_workspace(self, workspace_id: UUID) -> None: """Set the workspace to scope this response. @@ -264,6 +359,35 @@ def set_scope_workspace(self, workspace_id: UUID) -> None: """ self.scope_workspace = workspace_id + def get_custom_filters( + self, + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters() + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import ( + PipelineSchema, + WorkspaceSchema, + ) + + if self.workspace: + workspace_filter = and_( + PipelineSchema.user_id == WorkspaceSchema.id, + self.generate_name_or_id_query_conditions( + value=self.workspace, + table=WorkspaceSchema, + ), + ) + custom_filters.append(workspace_filter) + + return custom_filters + def apply_filter( self, query: AnyQuery, @@ -291,6 +415,41 @@ def apply_filter( return query + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import WorkspaceSchema + + sort_by, operand = self.sorting_params + + if sort_by == "workspace": + column = WorkspaceSchema.name + + query = query.join( + WorkspaceSchema, table.workspace_id == WorkspaceSchema.id + ) + + if operand == SorterOps.ASCENDING: + return query.order_by(asc(column)) + else: + return query.order_by(desc(column)) + + return super().apply_sorting(query=query, table=table) + class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): """Model to enable advanced scoping with workspace and tagging.""" diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 646e5c16ce1..4f3ef2a8052 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -471,7 +471,6 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): "name", "only_unused", "has_custom_name", - "user", "model", "pipeline_run", "model_version_id", @@ -516,19 +515,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): description="Artifact store for this artifact", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this artifact", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this artifact", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, - description="ID of the model version that is associated with this artifact version.", + description="ID of the model version that is associated with this " + "artifact version.", union_mode="left_to_right", ) only_unused: Optional[bool] = Field( @@ -580,7 +570,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: StepRunInputArtifactSchema, StepRunOutputArtifactSchema, StepRunSchema, - UserSchema, ) if self.name: @@ -628,17 +617,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ) custom_filters.append(custom_name_filter) - if self.user: - user_filter = and_( - ArtifactVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.model: model_filter = and_( ArtifactVersionSchema.id diff --git a/src/zenml/models/v2/core/code_repository.py b/src/zenml/models/v2/core/code_repository.py index c0a5430468b..485f710b7de 100644 --- a/src/zenml/models/v2/core/code_repository.py +++ b/src/zenml/models/v2/core/code_repository.py @@ -13,8 +13,7 @@ # permissions and limitations under the License. """Models representing code repositories.""" -from typing import Any, Dict, Optional, Union -from uuid import UUID +from typing import Any, Dict, Optional from pydantic import Field @@ -189,13 +188,3 @@ class CodeRepositoryFilter(WorkspaceScopedFilter): description="Name of the code repository.", default=None, ) - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace of the code repository.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that created the code repository.", - default=None, - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index a4f52be884c..ac73aa5e8dd 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -356,7 +356,6 @@ class ComponentFilter(WorkspaceScopedFilter): *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "scope_type", "stack_id", - "user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, @@ -366,7 +365,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="Name of the stack component", @@ -379,16 +377,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="Type of the stack component", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack component", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack component", - union_mode="left_to_right", - ) connector_id: Optional[Union[UUID, str]] = Field( default=None, description="Connector linked to the stack component", @@ -399,10 +387,6 @@ class ComponentFilter(WorkspaceScopedFilter): description="Stack of the stack component", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the component.", - ) def set_scope_type(self, component_type: str) -> None: """Set the type of component on which to perform the filtering to scope the response. @@ -449,31 +433,3 @@ def generate_filter( base_filter = operator(base_filter, stack_filter) return base_filter - - def get_custom_filters(self) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - StackComponentSchema, - UserSchema, - ) - - custom_filters = super().get_custom_filters() - - if self.user: - user_filter = and_( - StackComponentSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/flavor.py b/src/zenml/models/v2/core/flavor.py index fd4110300c3..77fe774c073 100644 --- a/src/zenml/models/v2/core/flavor.py +++ b/src/zenml/models/v2/core/flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing flavors.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional from uuid import UUID from pydantic import Field @@ -428,13 +428,3 @@ class FlavorFilter(WorkspaceScopedFilter): default=None, description="Integration associated with the flavor", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0eb3b749c88..0b5272ab7e6 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing models.""" -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from pydantic import BaseModel, Field @@ -30,8 +30,6 @@ from zenml.utils.pagination_utils import depaginate if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse @@ -318,61 +316,7 @@ def versions(self) -> List["Model"]: class ModelFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", - ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", - ] - name: Optional[str] = Field( default=None, description="Name of the Model", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Model", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Model", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model.", - ) - - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - ModelSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - ModelSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 9ce9b1692b5..c5f834c031e 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -585,7 +585,6 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "run_metadata", ] @@ -597,25 +596,11 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): default=None, description="The number of the Model Version", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The workspace of the Model Version", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The user of the Model Version", - union_mode="left_to_right", - ) stage: Optional[Union[str, ModelStages]] = Field( description="The model version stage", default=None, union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model version.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the model versions by.", @@ -653,20 +638,8 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, RunMetadataSchema, - UserSchema, ) - if self.user: - user_filter = and_( - ModelVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.run_metadata is not None: from zenml.enums import MetadataResourceTypes diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 5166e0abb9c..457b4989dd3 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -21,7 +21,6 @@ Optional, Type, TypeVar, - Union, ) from uuid import UUID @@ -45,8 +44,6 @@ from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - from zenml.models.v2.core.pipeline_run import PipelineRunResponse from zenml.zen_stores.schemas import BaseSchema @@ -261,10 +258,8 @@ class PipelineFilter(WorkspaceScopedTaggableFilter): CUSTOM_SORTING_OPTIONS = [SORT_PIPELINES_BY_LATEST_RUN_KEY] FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "latest_run_status", ] - name: Optional[str] = Field( default=None, description="Name of the Pipeline", @@ -274,20 +269,6 @@ class PipelineFilter(WorkspaceScopedTaggableFilter): description="Filter by the status of the latest run of a pipeline. " "This will always be applied as an `AND` filter for now.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Pipeline", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the pipeline.", - ) def apply_filter( self, query: AnyQuery, table: Type["AnySchema"] @@ -343,36 +324,6 @@ def apply_filter( return query - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - PipelineSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - PipelineSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters - def apply_sorting( self, query: AnyQuery, @@ -392,7 +343,7 @@ def apply_sorting( if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: # If sorting by the latest run, the sorting is already done in the # base query in `SqlZenStore.list_pipelines(...)` and we don't need - # to to anything here + # to do anything here return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index 3cb6dcb4e47..d110a8d9cd1 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -446,16 +446,6 @@ def contains_code(self) -> bool: class PipelineBuildFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline builds.""" - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace for this pipeline build.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that produced this pipeline build.", - default=None, - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( description="Pipeline associated with the pipeline build.", default=None, diff --git a/src/zenml/models/v2/core/pipeline_deployment.py b/src/zenml/models/v2/core/pipeline_deployment.py index 760f65f1a35..94dbc431507 100644 --- a/src/zenml/models/v2/core/pipeline_deployment.py +++ b/src/zenml/models/v2/core/pipeline_deployment.py @@ -358,16 +358,6 @@ def template_id(self) -> Optional[UUID]: class PipelineDeploymentFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline deployments.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this deployment.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created this deployment.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the deployment.", diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index a9cace6bc12..b4e6723a7d9 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -16,10 +16,13 @@ from datetime import datetime from typing import ( TYPE_CHECKING, + Any, ClassVar, Dict, List, Optional, + Type, + TypeVar, Union, cast, ) @@ -55,6 +58,11 @@ from zenml.models.v2.core.schedule import ScheduleResponse from zenml.models.v2.core.stack import StackResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -575,6 +583,7 @@ def tags(self) -> List[TagResponse]: class PipelineRunFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" + CUSTOM_SORTING_OPTIONS = ["stack"] FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "unlisted", @@ -583,7 +592,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): "schedule_id", "stack_id", "template_id", - "user", "pipeline", "stack", "code_repository", @@ -606,16 +614,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): description="Pipeline associated with the Pipeline Run", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline Run", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Pipeline Run", - union_mode="left_to_right", - ) stack_id: Optional[Union[UUID, str]] = Field( default=None, description="Stack used for the Pipeline Run", @@ -666,16 +664,12 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): union_mode="left_to_right", ) unlisted: Optional[bool] = None - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the run.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the pipeline runs by.", ) # TODO: Remove once frontend is ready for it. This is replaced by the more - # generic `pipeline` filter below. + # generic `pipeline` filter below. pipeline_name: Optional[str] = Field( default=None, description="Name of the pipeline associated with the run", @@ -731,7 +725,6 @@ def get_custom_filters( StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.unlisted is not None: @@ -782,17 +775,6 @@ def get_custom_filters( ) custom_filters.append(run_template_filter) - if self.user: - user_filter = and_( - PipelineRunSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( PipelineRunSchema.pipeline_id == PipelineSchema.id, @@ -914,3 +896,69 @@ def get_custom_filters( custom_filters.append(additional_filter) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + PipelineDeploymentSchema, + PipelineRunSchema, + PipelineSchema, + StackSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == "pipeline": + query = query.join( + PipelineSchema, + PipelineRunSchema.pipeline_id == PipelineSchema.id, + ) + column = PipelineSchema.name + elif sort_by == "stack": + query = query.join( + PipelineDeploymentSchema, + PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id, + ).join( + StackSchema, + PipelineDeploymentSchema.stack_id == StackSchema.id, + ) + column = StackSchema.name + elif sort_by == "model": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ).join( + ModelSchema, + ModelVersionSchema.model_id == ModelSchema.id, + ) + column = ModelSchema.name + elif sort_by == "model_version": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ) + column = ModelVersionSchema.name + else: + return super().apply_sorting(query=query, table=table) + + if operand == SorterOps.ASCENDING: + return query.order_by(asc(column)) + else: + return query.order_by(desc(column)) diff --git a/src/zenml/models/v2/core/run_template.py b/src/zenml/models/v2/core/run_template.py index b1aae8a325a..9a4fefc6671 100644 --- a/src/zenml/models/v2/core/run_template.py +++ b/src/zenml/models/v2/core/run_template.py @@ -310,16 +310,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the run template.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace associated with the template.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the template.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the template.", @@ -340,10 +330,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): description="Code repository associated with the template.", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the template.", - ) pipeline: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of the pipeline associated with the template.", @@ -371,7 +357,6 @@ def get_custom_filters( PipelineSchema, RunTemplateSchema, StackSchema, - UserSchema, ) if self.code_repository_id: @@ -409,17 +394,6 @@ def get_custom_filters( ) custom_filters.append(pipeline_filter) - if self.user: - user_filter = and_( - RunTemplateSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( RunTemplateSchema.source_deployment_id diff --git a/src/zenml/models/v2/core/schedule.py b/src/zenml/models/v2/core/schedule.py index af838f17ccc..0e7dc01c421 100644 --- a/src/zenml/models/v2/core/schedule.py +++ b/src/zenml/models/v2/core/schedule.py @@ -279,16 +279,6 @@ def pipeline_id(self) -> Optional[UUID]: class ScheduleFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all Users.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace scope of the schedule.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the schedule", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline that the schedule is attached to.", diff --git a/src/zenml/models/v2/core/secret.py b/src/zenml/models/v2/core/secret.py index 79e50cd1841..3f29b57de22 100644 --- a/src/zenml/models/v2/core/secret.py +++ b/src/zenml/models/v2/core/secret.py @@ -15,7 +15,6 @@ from datetime import datetime from typing import Any, ClassVar, Dict, List, Optional, Union -from uuid import UUID from pydantic import Field, SecretStr @@ -253,25 +252,12 @@ class SecretFilter(WorkspaceScopedFilter): default=None, description="Name of the secret", ) - scope: Optional[Union[SecretScope, str]] = Field( default=None, description="Scope in which to filter secrets", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Secret", - union_mode="left_to_right", - ) - - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Secret", - union_mode="left_to_right", - ) - @staticmethod def _get_filtering_value(value: Optional[Any]) -> str: """Convert the value to a string that can be used for lexicographical filtering and sorting. diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py index c3dcbd7cfc8..35ef181ce08 100644 --- a/src/zenml/models/v2/core/service.py +++ b/src/zenml/models/v2/core/service.py @@ -37,7 +37,6 @@ WorkspaceScopedResponseBody, WorkspaceScopedResponseMetadata, WorkspaceScopedResponseResources, - WorkspaceScopedTaggableFilter, ) from zenml.services.service_status import ServiceState from zenml.services.service_type import ServiceType @@ -376,16 +375,6 @@ class ServiceFilter(WorkspaceScopedFilter): description="Name of the service. Use this to filter services by " "their name.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the service", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the service", - union_mode="left_to_right", - ) type: Optional[str] = Field( default=None, description="Type of the service. Filter services by their type.", @@ -457,9 +446,7 @@ def set_flavor(self, flavor: str) -> None: "config", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", + *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, "flavor", "type", "pipeline_step_name", diff --git a/src/zenml/models/v2/core/service_connector.py b/src/zenml/models/v2/core/service_connector.py index 806e6100072..8c71106ae22 100644 --- a/src/zenml/models/v2/core/service_connector.py +++ b/src/zenml/models/v2/core/service_connector.py @@ -801,7 +801,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="The name to filter by", @@ -810,16 +809,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type of service connector to filter by", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace to filter by", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User to filter by", - union_mode="left_to_right", - ) auth_method: Optional[str] = Field( default=None, title="Filter by the authentication method configured for the " diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 3d8ad20a2c1..bb446d8886d 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -323,7 +323,6 @@ class StackFilter(WorkspaceScopedFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "component_id", - "user", "component", ] @@ -334,25 +333,11 @@ class StackFilter(WorkspaceScopedFilter): description: Optional[str] = Field( default=None, description="Description of the stack" ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) component_id: Optional[Union[UUID, str]] = Field( default=None, description="Component in the stack", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the stack.", - ) component: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of a component in the stack." ) @@ -369,7 +354,6 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.component_id: @@ -379,17 +363,6 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) custom_filters.append(component_id_filter) - if self.user: - user_filter = and_( - StackSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.component: component_filter = and_( StackCompositionSchema.stack_id == StackSchema.id, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 2916d9236ce..3762a12a9cb 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -553,16 +553,6 @@ class StepRunFilter(WorkspaceScopedFilter): description="Original id for this step run", union_mode="left_to_right", ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this step run", - union_mode="left_to_right", - ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of this step run", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, description="Model version associated with the step run.", @@ -576,7 +566,6 @@ class StepRunFilter(WorkspaceScopedFilter): default=None, description="The run_metadata to filter the step runs by.", ) - model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( From ffdf975bd9dbe24bac4a7ad397d387a3f7117da3 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 17:47:06 +0100 Subject: [PATCH 02/16] second checkpoint --- src/zenml/models/v2/base/scoped.py | 79 ++++++++++++++++++++++++ src/zenml/models/v2/core/pipeline.py | 50 ++++++++++++--- src/zenml/models/v2/core/pipeline_run.py | 10 ++- src/zenml/zen_stores/sql_zen_store.py | 61 +----------------- 4 files changed, 133 insertions(+), 67 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index dacee2b7cd1..f5383c0159b 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -463,6 +463,11 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): "tag", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, + "tag", + ] + def apply_filter( self, query: AnyQuery, @@ -506,3 +511,77 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc, func + + from zenml.enums import SorterOps, TaggableResourceTypes + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ModelSchema, + ModelVersionSchema, + PipelineRunSchema, + PipelineSchema, + RunTemplateSchema, + TagResourceSchema, + TagSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == "tag": + resource_type_mapping = { + ArtifactSchema: TaggableResourceTypes.ARTIFACT, + ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, + ModelSchema: TaggableResourceTypes.MODEL, + ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, + PipelineSchema: TaggableResourceTypes.PIPELINE, + PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, + RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, + } + + query = ( + query.outerjoin( + TagResourceSchema, + (table.id == TagResourceSchema.resource_id) + & ( + TagResourceSchema.resource_type + == resource_type_mapping[table] + ), + ) + .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) + .group_by(table.id) + ) + + if operand == SorterOps.ASCENDING: + return query.order_by( + asc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + else: + return query.order_by( + desc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 457b4989dd3..960106489f1 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -255,11 +255,15 @@ def tags(self) -> List[TagResponse]: class PipelineFilter(WorkspaceScopedTaggableFilter): """Pipeline filter model.""" - CUSTOM_SORTING_OPTIONS = [SORT_PIPELINES_BY_LATEST_RUN_KEY] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_PIPELINES_BY_LATEST_RUN_KEY, + ] FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "latest_run_status", ] + name: Optional[str] = Field( default=None, description="Name of the Pipeline", @@ -338,12 +342,44 @@ def apply_sorting( Returns: The query with sorting applied. """ - column, _ = self.sorting_params + from sqlmodel import asc, case, col, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema + + sort_by, operand = self.sorting_params + + if sort_by == SORT_PIPELINES_BY_LATEST_RUN_KEY: + # Subquery to find the latest run per pipeline + latest_run_subquery = ( + select( + PipelineRunSchema.pipeline_id, + case( + ( + func.max(PipelineRunSchema.created).is_(None), + PipelineSchema.created, + ), + else_=func.max(PipelineRunSchema.created), + ).label("latest_run"), + ) + .group_by(col(PipelineRunSchema.pipeline_id)) + .subquery() + ) + + # Join the subquery with the pipelines + query = query.outerjoin( + latest_run_subquery, + PipelineSchema.id == latest_run_subquery.c.pipeline_id, + ) + + if operand == SorterOps.ASCENDING: + return query.order_by( + asc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) + else: + return query.order_by( + desc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - # If sorting by the latest run, the sorting is already done in the - # base query in `SqlZenStore.list_pipelines(...)` and we don't need - # to do anything here - return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index b4e6723a7d9..869f4839728 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -583,7 +583,15 @@ def tags(self) -> List[TagResponse]: class PipelineRunFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" - CUSTOM_SORTING_OPTIONS = ["stack"] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + "tag", + "stack", + "pipeline", + "model", + "model_version", + ] + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "unlisted", diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 98ea30fb8e9..f6a06e77371 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -54,7 +54,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, case, desc, func +from sqlalchemy import func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( ArgumentError, @@ -99,7 +99,6 @@ ENV_ZENML_SERVER, FINISHED_ONBOARDING_SURVEY_KEY, MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION, - SORT_PIPELINES_BY_LATEST_RUN_KEY, SQL_STORE_BACKUP_DIRECTORY_NAME, TEXT_FIELD_MAX_LENGTH, handle_bool_env_var, @@ -116,7 +115,6 @@ OnboardingStep, SecretScope, SecretsStoreType, - SorterOps, StackComponentType, StackDeploymentProvider, StepRunInputArtifactType, @@ -4328,69 +4326,14 @@ def list_pipelines( Returns: A list of all pipelines matching the filter criteria. """ - query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema) - _custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None - - column, operand = pipeline_filter_model.sorting_params - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - with Session(self.engine) as session: - max_date_subquery = ( - # If no run exists for the pipeline yet, we use the pipeline - # creation date as a fallback, otherwise newly created - # pipeline would always be at the top/bottom - select( - PipelineSchema.id, - case( - ( - func.max(PipelineRunSchema.created).is_(None), - PipelineSchema.created, - ), - else_=func.max(PipelineRunSchema.created), - ).label("run_or_created"), - ) - .outerjoin( - PipelineRunSchema, - PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type] - ) - .group_by(col(PipelineSchema.id)) - .subquery() - ) - - if operand == SorterOps.DESCENDING: - sort_clause = desc - else: - sort_clause = asc - - query = ( - # We need to include the subquery in the select here to - # make this query work with the distinct statement. This - # result will be removed in the custom conversion function - # applied later - select(PipelineSchema, max_date_subquery.c.run_or_created) - .where(PipelineSchema.id == max_date_subquery.c.id) - .order_by(sort_clause(max_date_subquery.c.run_or_created)) - # We always add the `id` column as a tiebreaker to ensure a - # stable, repeatable order of items, otherwise subsequent - # pages might contain the same items. - .order_by(col(PipelineSchema.id)) - ) - - def _custom_conversion(row: Any) -> PipelineResponse: - return cast( - PipelineResponse, - row[0].to_model( - include_metadata=hydrate, include_resources=True - ), - ) - with Session(self.engine) as session: + query = select(PipelineSchema) return self.filter_and_paginate( session=session, query=query, table=PipelineSchema, filter_model=pipeline_filter_model, hydrate=hydrate, - custom_schema_to_model_conversion=_custom_conversion, ) def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int: From 155bbe2d11d9f8f72bdca50b8a603609b11590c8 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Fri, 29 Nov 2024 18:07:36 +0100 Subject: [PATCH 03/16] formatting and linting --- src/zenml/models/v2/base/scoped.py | 25 ++++++++++++++++-------- src/zenml/models/v2/core/pipeline.py | 5 +++-- src/zenml/models/v2/core/pipeline_run.py | 6 ++++-- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f5383c0159b..75e3934b533 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -241,12 +241,16 @@ def apply_sorting( if sort_by == "user": column = UserSchema.name - query = query.join(UserSchema, table.user_id == UserSchema.id) + query = query.join( + UserSchema, getattr(table, "user_id") == UserSchema.id + ) if operand == SorterOps.ASCENDING: - return query.order_by(asc(column)) + query = query.order_by(asc(column)) else: - return query.order_by(desc(column)) + query = query.order_by(desc(column)) + + return query return super().apply_sorting(query=query, table=table) @@ -440,13 +444,16 @@ def apply_sorting( column = WorkspaceSchema.name query = query.join( - WorkspaceSchema, table.workspace_id == WorkspaceSchema.id + WorkspaceSchema, + getattr(table, "workspace_id") == WorkspaceSchema.id, ) if operand == SorterOps.ASCENDING: - return query.order_by(asc(column)) + query = query.order_by(asc(column)) else: - return query.order_by(desc(column)) + query = query.order_by(desc(column)) + + return query return super().apply_sorting(query=query, table=table) @@ -568,7 +575,7 @@ def apply_sorting( ) if operand == SorterOps.ASCENDING: - return query.order_by( + query = query.order_by( asc( func.group_concat(TagSchema.name, ",").label( "tags_list" @@ -576,7 +583,7 @@ def apply_sorting( ) ) else: - return query.order_by( + query = query.order_by( desc( func.group_concat(TagSchema.name, ",").label( "tags_list" @@ -584,4 +591,6 @@ def apply_sorting( ) ) + return query + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 960106489f1..03a81fbb23c 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -373,13 +373,14 @@ def apply_sorting( ) if operand == SorterOps.ASCENDING: - return query.order_by( + query = query.order_by( asc(latest_run_subquery.c.latest_run) ).order_by(col(PipelineSchema.id)) else: - return query.order_by( + query = query.order_by( desc(latest_run_subquery.c.latest_run) ).order_by(col(PipelineSchema.id)) + return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 869f4839728..348ab39e493 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -967,6 +967,8 @@ def apply_sorting( return super().apply_sorting(query=query, table=table) if operand == SorterOps.ASCENDING: - return query.order_by(asc(column)) + query = query.order_by(asc(column)) else: - return query.order_by(desc(column)) + query = query.order_by(desc(column)) + + return query From fcf898ae390748caf5a6c673b5d0bec0c144d6d1 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 29 Nov 2024 17:10:34 +0000 Subject: [PATCH 04/16] Auto-update of LLM Finetuning template --- examples/llm_finetuning/.copier-answers.yml | 2 +- examples/llm_finetuning/steps/log_metadata.py | 9 +++++++-- examples/llm_finetuning/steps/prepare_datasets.py | 9 +++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 2c547f98d61..7deecebb1d2 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08 +_commit: 2024.11.08-2-gece1d46 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 diff --git a/examples/llm_finetuning/steps/log_metadata.py b/examples/llm_finetuning/steps/log_metadata.py index 645f98cc8ea..90109fdf3c4 100644 --- a/examples/llm_finetuning/steps/log_metadata.py +++ b/examples/llm_finetuning/steps/log_metadata.py @@ -17,7 +17,7 @@ from typing import Any, Dict -from zenml import get_step_context, log_model_metadata, step +from zenml import get_step_context, log_metadata, step @step(enable_cache=False) @@ -39,4 +39,9 @@ def log_metadata_from_step_artifact( metadata = {artifact_name: metadata_dict} - log_model_metadata(metadata) + if context.model: + log_metadata( + metadata=metadata, + model_name=context.model.name, + model_version=context.model.version, + ) diff --git a/examples/llm_finetuning/steps/prepare_datasets.py b/examples/llm_finetuning/steps/prepare_datasets.py index fe98126369d..b9cc13c2261 100644 --- a/examples/llm_finetuning/steps/prepare_datasets.py +++ b/examples/llm_finetuning/steps/prepare_datasets.py @@ -22,7 +22,7 @@ from typing_extensions import Annotated from utils.tokenizer import generate_and_tokenize_prompt, load_tokenizer -from zenml import log_model_metadata, step +from zenml import log_metadata, step from zenml.materializers import BuiltInMaterializer from zenml.utils.cuda_utils import cleanup_gpu_memory @@ -49,11 +49,12 @@ def prepare_data( cleanup_gpu_memory(force=True) - log_model_metadata( - { + log_metadata( + metadata={ "system_prompt": system_prompt, "base_model_id": base_model_id, - } + }, + infer_model=True, ) tokenizer = load_tokenizer(base_model_id, False, use_fast) From f10b93de36a6ea6dda5fa08599732961fd835159 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 29 Nov 2024 17:11:04 +0000 Subject: [PATCH 05/16] Auto-update of Starter template --- examples/mlops_starter/.copier-answers.yml | 2 +- .../mlops_starter/steps/data_preprocessor.py | 7 ++++--- .../mlops_starter/steps/model_evaluator.py | 19 ++++++++++++------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index fd6b937c7c9..364bccaa9d0 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-7-gb60e441 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/examples/mlops_starter/steps/data_preprocessor.py b/examples/mlops_starter/steps/data_preprocessor.py index 0cf9d3ab521..f94d1e85f6d 100644 --- a/examples/mlops_starter/steps/data_preprocessor.py +++ b/examples/mlops_starter/steps/data_preprocessor.py @@ -23,7 +23,7 @@ from typing_extensions import Annotated from utils.preprocess import ColumnsDropper, DataFrameCaster, NADropper -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step @step @@ -87,8 +87,9 @@ def data_preprocessor( dataset_tst = preprocess_pipeline.transform(dataset_tst) # Log metadata so we can load it in the inference pipeline - log_artifact_metadata( - artifact_name="preprocess_pipeline", + log_metadata( metadata={"random_state": random_state, "target": target}, + artifact_name="preprocess_pipeline", + infer_artifact=True, ) return dataset_trn, dataset_tst, preprocess_pipeline diff --git a/examples/mlops_starter/steps/model_evaluator.py b/examples/mlops_starter/steps/model_evaluator.py index 2a9b6ee9e75..c63c53109f4 100644 --- a/examples/mlops_starter/steps/model_evaluator.py +++ b/examples/mlops_starter/steps/model_evaluator.py @@ -20,7 +20,8 @@ import pandas as pd from sklearn.base import ClassifierMixin -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step +from zenml.client import Client from zenml.logger import get_logger logger = get_logger(__name__) @@ -79,27 +80,31 @@ def model_evaluator( dataset_tst.drop(columns=[target]), dataset_tst[target], ) - logger.info(f"Train accuracy={trn_acc*100:.2f}%") - logger.info(f"Test accuracy={tst_acc*100:.2f}%") + logger.info(f"Train accuracy={trn_acc * 100:.2f}%") + logger.info(f"Test accuracy={tst_acc * 100:.2f}%") messages = [] if trn_acc < min_train_accuracy: messages.append( - f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !" + f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !" ) if tst_acc < min_test_accuracy: messages.append( - f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !" + f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !" ) else: for message in messages: logger.warning(message) - log_artifact_metadata( + client = Client() + latest_classifier = client.get_artifact_version("sklearn_classifier") + + log_metadata( metadata={ "train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc), }, - artifact_name="sklearn_classifier", + artifact_version_id=latest_classifier.id, ) + return float(tst_acc) From 799eb5b77d80a689f2e11abae9577d8faa6ee4b9 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 29 Nov 2024 17:17:38 +0000 Subject: [PATCH 06/16] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- examples/e2e/steps/hp_tuning/hp_tuning_single_search.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 0a2f40d5a92..e6fb1292beb 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20 +_commit: 2024.11.20-2-g760142f _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py index f2f39969a6f..7b55eebae7a 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py @@ -25,7 +25,7 @@ from typing_extensions import Annotated from utils import get_model_from_config -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -95,9 +95,10 @@ def hp_tuning_single_search( y_pred = cv.predict(X_tst) score = accuracy_score(y_tst, y_pred) # log score along with output artifact as metadata - log_artifact_metadata( + log_metadata( metadata={"metric": float(score)}, artifact_name="hp_result", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### return cv.best_estimator_ From 2cf9e817dbb7b382967a9e43761cdad270890855 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Fri, 29 Nov 2024 17:20:31 +0000 Subject: [PATCH 07/16] Auto-update of NLP template --- examples/e2e_nlp/.copier-answers.yml | 2 +- examples/e2e_nlp/steps/training/model_trainer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index e13858e7da1..274927e3ce5 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-2-g1ae14e3 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/e2e_nlp/steps/training/model_trainer.py b/examples/e2e_nlp/steps/training/model_trainer.py index edb9ab23ba5..0a3de574c09 100644 --- a/examples/e2e_nlp/steps/training/model_trainer.py +++ b/examples/e2e_nlp/steps/training/model_trainer.py @@ -30,7 +30,7 @@ from typing_extensions import Annotated from utils.misc import compute_metrics -from zenml import ArtifactConfig, log_artifact_metadata, step +from zenml import ArtifactConfig, log_metadata, step from zenml.client import Client from zenml.integrations.mlflow.experiment_trackers import ( MLFlowExperimentTracker, @@ -157,9 +157,10 @@ def model_trainer( eval_results = trainer.evaluate(metric_key_prefix="") # Log the evaluation results in model control plane - log_artifact_metadata( + log_metadata( metadata={"metrics": eval_results}, artifact_name="model", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### From 0599bf3689840ea92aa84c870ff0859e98f12b1d Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 2 Dec 2024 12:51:22 +0000 Subject: [PATCH 08/16] Auto-update of LLM Finetuning template --- examples/llm_finetuning/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 7deecebb1d2..47bfa4cf2af 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08-2-gece1d46 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 From 59d515bd95be6107ea39d91df0a5b1dbfc9d7ecb Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 2 Dec 2024 12:51:51 +0000 Subject: [PATCH 09/16] Auto-update of Starter template --- examples/mlops_starter/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index 364bccaa9d0..ec87b32240d 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-7-gb60e441 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH From 8b94066e575948f15d4cfe22b235ec5d2d8a3037 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 2 Dec 2024 12:58:32 +0000 Subject: [PATCH 10/16] Auto-update of E2E template --- examples/e2e/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index e6fb1292beb..a9c5acbf8c3 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20-2-g760142f +_commit: 2024.11.28 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io From b70a5637fa3aebf163311cb831d39cf69d457ece Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 2 Dec 2024 13:01:43 +0000 Subject: [PATCH 11/16] Auto-update of NLP template --- examples/e2e_nlp/.copier-answers.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index 274927e3ce5..b12ebdd786b 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-2-g1ae14e3 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws From 01d752c3272305b8135f0778cfeebe8bfe202635 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 2 Dec 2024 15:30:24 +0100 Subject: [PATCH 12/16] formatting --- .../service_connectors/service_connector_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/zenml/service_connectors/service_connector_utils.py b/src/zenml/service_connectors/service_connector_utils.py index d97f097faf5..a20f4847ee3 100644 --- a/src/zenml/service_connectors/service_connector_utils.py +++ b/src/zenml/service_connectors/service_connector_utils.py @@ -60,15 +60,9 @@ def _raise_specific_cloud_exception_if_needed( orchestrators: List[ResourcesInfo], container_registries: List[ResourcesInfo], ) -> None: - AWS_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector" - ) - GCP_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector" - ) - AZURE_DOCS = ( - "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector" - ) + AWS_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector" + GCP_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector" + AZURE_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector" if not artifact_stores: error_msg = ( From e60bd688a664d0edb927e47a78762bd2ec8bc330 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 2 Dec 2024 17:21:10 +0100 Subject: [PATCH 13/16] fixing the failing cli integration tests --- src/zenml/client.py | 46 ++++++++++++++++++++++++++++-- src/zenml/models/v2/base/scoped.py | 1 + 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 995f2d8bdb3..1341cc3c764 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -1702,6 +1702,7 @@ def list_services( updated: Optional[datetime] = None, type: Optional[str] = None, flavor: Optional[str] = None, + user: Optional[Union[UUID, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -1727,6 +1728,7 @@ def list_services( flavor: Use the service flavor for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. running: Use the running status for filtering @@ -1753,6 +1755,7 @@ def list_services( flavor=flavor, workspace_id=workspace_id, user_id=user_id, + user=user, running=running, name=service_name, pipeline_name=pipeline_name, @@ -2249,6 +2252,7 @@ def list_flavors( type: Optional[str] = None, integration: Optional[str] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[FlavorResponse]: """Fetches all the flavor models. @@ -2262,6 +2266,7 @@ def list_flavors( created: Use to flavors by time of creation updated: Use the last updated date for filtering user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the flavor to filter by. type: The type of the flavor to filter by. integration: The integration of the flavor to filter by. @@ -2277,6 +2282,7 @@ def list_flavors( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, name=name, type=type, integration=integration, @@ -2661,6 +2667,7 @@ def list_builds( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, is_local: Optional[bool] = None, @@ -2682,6 +2689,7 @@ def list_builds( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. is_local: Use to filter local builds. @@ -2705,6 +2713,7 @@ def list_builds( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, is_local=is_local, @@ -2771,7 +2780,7 @@ def get_event_source( allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> EventSourceResponse: - """Get a event source by name, ID or prefix. + """Get an event source by name, ID or prefix. Args: name_id_or_prefix: The name, ID or prefix of the stack. @@ -2804,6 +2813,7 @@ def list_event_sources( event_source_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[EventSourceResponse]: """Lists all event_sources. @@ -2818,6 +2828,7 @@ def list_event_sources( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the event_source to filter by. flavor: The flavor of the event_source to filter by. event_source_type: The subtype of the event_source to filter by. @@ -2834,6 +2845,7 @@ def list_event_sources( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, flavor=flavor, plugin_subtype=event_source_type, @@ -3001,6 +3013,7 @@ def list_actions( action_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[ActionResponse]: """List actions. @@ -3015,6 +3028,7 @@ def list_actions( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the action to filter by. flavor: The flavor of the action to filter by. action_type: The type of the action to filter by. @@ -3031,6 +3045,7 @@ def list_actions( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, id=id, flavor=flavor, @@ -3179,6 +3194,7 @@ def list_triggers( action_subtype: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerResponse]: """Lists all triggers. @@ -3193,6 +3209,7 @@ def list_triggers( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the trigger to filter by. event_source_id: The event source associated with the trigger. action_id: The action associated with the trigger. @@ -3215,6 +3232,7 @@ def list_triggers( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, event_source_id=event_source_id, action_id=action_id, @@ -3365,6 +3383,7 @@ def list_deployments( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, build_id: Optional[Union[str, UUID]] = None, @@ -3383,6 +3402,7 @@ def list_deployments( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. build_id: The id of the build to filter by. @@ -3403,6 +3423,7 @@ def list_deployments( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, build_id=build_id, @@ -3650,6 +3671,7 @@ def list_schedules( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, orchestrator_id: Optional[Union[str, UUID]] = None, active: Optional[Union[str, bool]] = None, @@ -3674,6 +3696,7 @@ def list_schedules( name: The name of the stack to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. orchestrator_id: The id of the orchestrator to filter by. active: Use to filter by active status. @@ -3700,6 +3723,7 @@ def list_schedules( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, orchestrator_id=orchestrator_id, active=active, @@ -3940,6 +3964,7 @@ def list_run_steps( original_step_run_id: Optional[Union[str, UUID]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, model_version_id: Optional[Union[str, UUID]] = None, model: Optional[Union[UUID, str]] = None, hydrate: bool = False, @@ -3958,6 +3983,7 @@ def list_run_steps( end_time: Use to filter by the time when the step finished running workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_run_id: The id of the pipeline run to filter by. deployment_id: The id of the deployment to filter by. original_step_run_id: The id of the original step run to filter by. @@ -3992,6 +4018,7 @@ def list_run_steps( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, model_version_id=model_version_id, model=model, ) @@ -4664,6 +4691,7 @@ def list_secrets( scope: Optional[SecretScope] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[SecretResponse]: """Fetches all the secret models. @@ -4683,6 +4711,7 @@ def list_secrets( scope: The scope of the secret to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -4699,6 +4728,7 @@ def list_secrets( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, workspace_id=workspace_id, name=name, scope=scope, @@ -5013,6 +5043,7 @@ def list_code_repositories( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[CodeRepositoryResponse]: """List all code repositories. @@ -5028,6 +5059,7 @@ def list_code_repositories( name: The name of the code repository to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -5045,6 +5077,7 @@ def list_code_repositories( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, ) filter_model.set_scope_workspace(self.active_workspace.id) return self.zen_store.list_code_repositories( @@ -5405,6 +5438,7 @@ def list_service_connectors( resource_id: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, labels: Optional[Dict[str, Optional[str]]] = None, secret_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -5427,6 +5461,7 @@ def list_service_connectors( they can give access to. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the service connector to filter by. labels: The labels of the service connector to filter by. secret_id: Filter by the id of the secret that is referenced by the @@ -5444,6 +5479,7 @@ def list_service_connectors( logical_operator=logical_operator, workspace_id=workspace_id or self.active_workspace.id, user_id=user_id, + user=user, name=name, connector_type=connector_type, auth_method=auth_method, @@ -6596,6 +6632,7 @@ def list_authorized_devices( client_id: Union[UUID, str, None] = None, status: Union[OAuthDeviceStatus, str, None] = None, trusted_device: Union[bool, str, None] = None, + user: Optional[Union[UUID, str]] = None, failed_auth_attempts: Union[int, str, None] = None, last_login: Optional[Union[datetime, str, None]] = None, hydrate: bool = False, @@ -6613,6 +6650,7 @@ def list_authorized_devices( expires: Use the expiration date for filtering. client_id: Use the client id for filtering. status: Use the status for filtering. + user: Filter by user name/ID. trusted_device: Use the trusted device flag for filtering. failed_auth_attempts: Use the failed auth attempts for filtering. last_login: Use the last login date for filtering. @@ -6632,6 +6670,7 @@ def list_authorized_devices( updated=updated, expires=expires, client_id=client_id, + user=user, status=status, trusted_device=trusted_device, failed_auth_attempts=failed_auth_attempts, @@ -6730,7 +6769,7 @@ def get_trigger_execution( trigger_execution_id: UUID, hydrate: bool = True, ) -> TriggerExecutionResponse: - """Get an trigger execution by ID. + """Get a trigger execution by ID. Args: trigger_execution_id: The ID of the trigger execution to get. @@ -6751,6 +6790,7 @@ def list_trigger_executions( size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, trigger_id: Optional[UUID] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerExecutionResponse]: """List all trigger executions matching the given filter criteria. @@ -6761,6 +6801,7 @@ def list_trigger_executions( size: The maximum size of all pages. logical_operator: Which logical operator to use [and, or]. trigger_id: ID of the trigger to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -6772,6 +6813,7 @@ def list_trigger_executions( sort_by=sort_by, page=page, size=size, + user=user, logical_operator=logical_operator, ) filter_model.set_scope_workspace(self.active_workspace.id) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 75e3934b533..19d3cf7b081 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -335,6 +335,7 @@ class WorkspaceScopedFilter(UserScopedFilter): CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *UserScopedFilter.CLI_EXCLUDE_FIELDS, "workspace_id", + "workspace", "scope_workspace", ] CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ From 8b7cc2167e754696d8bcd550c6450625c93f8ab9 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 11 Dec 2024 14:27:45 +0100 Subject: [PATCH 14/16] adjusting the models --- .../kubernetes_step_operator.py | 1 - src/zenml/models/v2/base/filter.py | 11 ++- src/zenml/models/v2/base/scoped.py | 67 ++++++++++--------- src/zenml/models/v2/core/artifact_version.py | 15 ++++- src/zenml/models/v2/core/component.py | 7 +- src/zenml/models/v2/core/model_version.py | 7 +- .../models/v2/core/model_version_artifact.py | 14 +++- .../v2/core/model_version_pipeline_run.py | 17 ++++- src/zenml/models/v2/core/pipeline_run.py | 6 +- src/zenml/models/v2/core/run_template.py | 24 ++++++- src/zenml/models/v2/core/service.py | 10 ++- src/zenml/models/v2/core/stack.py | 24 ++++++- src/zenml/models/v2/core/step_run.py | 21 +++++- src/zenml/models/v2/core/trigger.py | 22 +++++- 14 files changed, 184 insertions(+), 62 deletions(-) diff --git a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py index 0b7b01b546d..52b19af2afe 100644 --- a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py @@ -33,7 +33,6 @@ from zenml.step_operators import BaseStepOperator if TYPE_CHECKING: - from zenml.config.base_settings import BaseSettings from zenml.config.step_run_info import StepRunInfo from zenml.models import PipelineDeploymentBase diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index bb15aa16f59..1b79696134a 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -758,7 +758,7 @@ def offset(self) -> int: return self.size * (self.page - 1) def generate_filter( - self, table: Type[SQLModel] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -778,7 +778,7 @@ def generate_filter( filters.append( column_filter.generate_query_conditions(table=table) ) - for custom_filter in self.get_custom_filters(): + for custom_filter in self.get_custom_filters(table): filters.append(custom_filter) if self.logical_operator == LogicalOperators.OR: return or_(False, *filters) @@ -787,12 +787,17 @@ def generate_filter( else: raise RuntimeError("No valid logical operator was supplied.") - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. This can be overridden by subclasses to define custom filters that are not based on the columns of the underlying table. + Args: + table: The query table. + Returns: A list of custom filters. """ diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 19d3cf7b081..830ad18d86e 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -188,25 +188,25 @@ def set_scope_user(self, user_id: UUID) -> None: self.scope_user = user_id def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ - from zenml.zen_stores.schemas import ( - PipelineSchema, - UserSchema, - ) + from zenml.zen_stores.schemas import UserSchema if self.user: user_filter = and_( - PipelineSchema.user_id == UserSchema.id, + getattr(table, "user_id") == UserSchema.id, self.generate_name_or_id_query_conditions( value=self.user, table=UserSchema, @@ -365,25 +365,25 @@ def set_scope_workspace(self, workspace_id: UUID) -> None: self.scope_workspace = workspace_id def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ - from zenml.zen_stores.schemas import ( - PipelineSchema, - WorkspaceSchema, - ) + from zenml.zen_stores.schemas import WorkspaceSchema if self.workspace: workspace_filter = and_( - PipelineSchema.user_id == WorkspaceSchema.id, + getattr(table, "workspace_id") == WorkspaceSchema.id, self.generate_name_or_id_query_conditions( value=self.workspace, table=WorkspaceSchema, @@ -502,15 +502,20 @@ def apply_filter( return query - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom tag filters. + Args: + table: The query table. + Returns: A list of custom filters. """ from zenml.zen_stores.schemas import TagSchema - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.tag: custom_filters.append( self.generate_custom_query_conditions_for_column( @@ -534,24 +539,24 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, desc, func - - from zenml.enums import SorterOps, TaggableResourceTypes - from zenml.zen_stores.schemas import ( - ArtifactSchema, - ArtifactVersionSchema, - ModelSchema, - ModelVersionSchema, - PipelineRunSchema, - PipelineSchema, - RunTemplateSchema, - TagResourceSchema, - TagSchema, - ) - sort_by, operand = self.sorting_params if sort_by == "tag": + from sqlmodel import asc, desc, func + + from zenml.enums import SorterOps, TaggableResourceTypes + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ModelSchema, + ModelVersionSchema, + PipelineRunSchema, + PipelineSchema, + RunTemplateSchema, + TagResourceSchema, + TagSchema, + ) + resource_type_mapping = { ArtifactSchema: TaggableResourceTypes.ARTIFACT, ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 5805679b38a..a6998b92b3c 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -20,6 +20,8 @@ Dict, List, Optional, + Type, + TypeVar, Union, ) from uuid import UUID @@ -58,6 +60,10 @@ ) from zenml.models.v2.core.pipeline_run import PipelineRunResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + logger = get_logger(__name__) @@ -549,13 +555,18 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, or_, select diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index ac73aa5e8dd..98418589222 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -21,6 +21,7 @@ List, Optional, Type, + TypeVar, Union, ) from uuid import UUID @@ -42,9 +43,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement - from sqlmodel import SQLModel from zenml.models import FlavorResponse, ServiceConnectorResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Base Model ------------------ @@ -397,7 +400,7 @@ def set_scope_type(self, component_type: str) -> None: self.scope_type = component_type def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 2d67069117a..949d9ce1d15 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -624,14 +624,17 @@ def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: self._model_id = model_id def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index f3a677a86e9..6c9514b9735 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and artifacts.""" -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -32,6 +32,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -164,13 +167,18 @@ class ModelVersionArtifactFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col diff --git a/src/zenml/models/v2/core/model_version_pipeline_run.py b/src/zenml/models/v2/core/model_version_pipeline_run.py index 6181c2ffbb1..40e7f823d9c 100644 --- a/src/zenml/models/v2/core/model_version_pipeline_run.py +++ b/src/zenml/models/v2/core/model_version_pipeline_run.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and pipeline runs.""" -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -30,6 +30,12 @@ from zenml.models.v2.base.filter import BaseFilter, StrFilter from zenml.models.v2.core.pipeline_run import PipelineRunResponse +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -147,13 +153,18 @@ class ModelVersionPipelineRunFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index b507900a809..3a22f642953 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -718,13 +718,17 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): def get_custom_filters( self, + table: Type["AnySchema"], ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col, or_ diff --git a/src/zenml/models/v2/core/run_template.py b/src/zenml/models/v2/core/run_template.py index 9a4fefc6671..2bc177c043e 100644 --- a/src/zenml/models/v2/core/run_template.py +++ b/src/zenml/models/v2/core/run_template.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Models representing pipeline templates.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -45,6 +55,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -340,14 +355,17 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py index 35ef181ce08..2ad9724b20a 100644 --- a/src/zenml/models/v2/core/service.py +++ b/src/zenml/models/v2/core/service.py @@ -15,19 +15,20 @@ from datetime import datetime from typing import ( + TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, + TypeVar, Union, ) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.sql.elements import ColumnElement -from sqlmodel import SQLModel from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.models.v2.base.scoped import ( @@ -41,6 +42,11 @@ from zenml.services.service_status import ServiceState from zenml.services.service_type import ServiceType +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + # ------------------ Request Model ------------------ @@ -455,7 +461,7 @@ def set_flavor(self, flavor: str) -> None: ] def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index bb446d8886d..1e49eb1544b 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -14,7 +14,17 @@ """Models representing stacks.""" import json -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.component import ComponentResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -342,13 +355,18 @@ class StackFilter(WorkspaceScopedFilter): default=None, description="Name/ID of a component in the stack." ) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from zenml.zen_stores.schemas import ( StackComponentSchema, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 504f9db09d3..0a505539d07 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -14,7 +14,16 @@ """Models representing steps runs.""" from datetime import datetime -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -41,6 +50,9 @@ LogsRequest, LogsResponse, ) + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) class StepRunInputResponse(ArtifactVersionResponse): @@ -569,14 +581,17 @@ class StepRunFilter(WorkspaceScopedFilter): model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/trigger.py b/src/zenml/models/v2/core/trigger.py index daef211ed7b..45fc23a501c 100644 --- a/src/zenml/models/v2/core/trigger.py +++ b/src/zenml/models/v2/core/trigger.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Collection of all models concerning triggers.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ ActionResponse, ) from zenml.models.v2.core.event_source import EventSourceResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -358,10 +371,13 @@ class TriggerFilter(WorkspaceScopedFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ @@ -373,7 +389,7 @@ def get_custom_filters( TriggerSchema, ) - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.event_source_flavor: event_source_flavor_filter = and_( From 43810d00ab78f41a46501172e63d93473094b734 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Wed, 11 Dec 2024 14:43:09 +0100 Subject: [PATCH 15/16] consistency check --- src/zenml/models/v2/base/scoped.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 830ad18d86e..f5267f4840d 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -542,7 +542,7 @@ def apply_sorting( sort_by, operand = self.sorting_params if sort_by == "tag": - from sqlmodel import asc, desc, func + from sqlmodel import and_, asc, desc, func from zenml.enums import SorterOps, TaggableResourceTypes from zenml.zen_stores.schemas import ( @@ -570,10 +570,10 @@ def apply_sorting( query = ( query.outerjoin( TagResourceSchema, - (table.id == TagResourceSchema.resource_id) - & ( + and_( + table.id == TagResourceSchema.resource_id, TagResourceSchema.resource_type - == resource_type_mapping[table] + == resource_type_mapping[table], ), ) .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) From 7b7f50a07e3d9c2e3f45b56ebd178b1026ab9518 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Thu, 12 Dec 2024 17:56:23 +0100 Subject: [PATCH 16/16] formatting --- src/zenml/models/v2/core/pipeline_build.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index c0dcff81e38..19dc89ccbf0 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -14,7 +14,17 @@ """Models representing pipeline builds.""" import json -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -40,7 +50,6 @@ AnySchema = TypeVar("AnySchema", bound=BaseSchema) - # ------------------ Request Model ------------------ @@ -495,7 +504,8 @@ class PipelineBuildFilter(WorkspaceScopedFilter): ) def get_custom_filters( - self, table: Type["AnySchema"], + self, + table: Type["AnySchema"], ) -> List["ColumnElement[bool]"]: """Get custom filters.