From d6fae4ef12536e1455c92a0f45882c0e4ad40a10 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 13 Dec 2024 14:29:45 +0100 Subject: [PATCH] Create model versions server-side to avoid race conditions (#3254) * POC * Log model version if created * Fetch model version if it already exists * Test * Missing docstring * Refactor fetching model version into separate method * Remove FK cycle * Add DB migration * Cleanup * Move to computed property to solve alembic issue * mypy * Some cleanup * More tests and fixes * Typo * Move logs so they don't happen when creating the object * Always log model version * Mysql fix * Linting * More linting * Make columns non-nullable * Formatting * Fix some tests * Refactor * Fix comparison for mysql * Implement RBAC checks inside SQL ZenStore * Docstrings * Try to fix vertex test * Fix alembic order * Ignore generated columns in DB migration * Linting * Rework entire DB stuff to work with mariadb * Revert change to DB backup * Add missing null check * Try to fix vertex again --- src/zenml/model/model.py | 28 +- src/zenml/models/v2/core/model_version.py | 4 - src/zenml/orchestrators/step_launcher.py | 22 +- src/zenml/orchestrators/step_run_utils.py | 212 +---------- .../zen_server/rbac/rbac_sql_zen_store.py | 173 +++++++++ src/zenml/zen_stores/base_zen_store.py | 12 +- ..._add_model_version_producer_run_unique_.py | 68 ++++ src/zenml/zen_stores/schemas/model_schemas.py | 48 ++- .../schemas/pipeline_deployment_schemas.py | 14 +- src/zenml/zen_stores/sql_zen_store.py | 347 ++++++++++++++++-- .../functional/model/test_model_version.py | 12 +- .../pipelines/test_pipeline_context.py | 2 +- .../functional/steps/test_model_version.py | 148 +++++++- .../orchestrators/test_vertex_orchestrator.py | 36 +- tests/unit/model/test_model_version_init.py | 27 -- 15 files changed, 828 insertions(+), 325 deletions(-) create mode 100644 src/zenml/zen_server/rbac/rbac_sql_zen_store.py create mode 100644 src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py delete mode 100644 tests/unit/model/test_model_version_init.py diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index b7f3c591518..f7987fc1b58 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -509,22 +509,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]: raise ValueError( "`model_version_id` field is for internal use only" ) - - version = data.get("version", None) - - if ( - version in [stage.value for stage in ModelStages] - and not suppress_class_validation_warnings - ): - logger.info( - f"Version `{version}` matches one of the possible " - "`ModelStages` and will be fetched using stage." - ) - if str(version).isnumeric() and not suppress_class_validation_warnings: - logger.info( - f"`version` `{version}` is numeric and will be fetched " - "using version number." - ) data["suppress_class_validation_warnings"] = True return data @@ -603,6 +587,18 @@ def _get_model_version( hydrate=hydrate, ) else: + if self.version in ModelStages.values(): + logger.info( + f"Version `{self.version}` for model {self.name} matches " + "one of the possible `ModelStages` and will be fetched " + "using stage." + ) + if str(self.version).isnumeric(): + logger.info( + f"Version `{self.version}` for model {self.name} is " + "numeric and will be fetched using version number." + ) + mv = zenml_client.get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 949d9ce1d15..80880f1e70e 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -77,10 +77,6 @@ class ModelVersionRequest(WorkspaceScopedRequest): default=None, ) - number: Optional[int] = Field( - description="The number of the model version", - default=None, - ) model: UUID = Field( description="The ID of the model containing version", ) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 6db9c085a89..1141172bf31 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -179,12 +179,10 @@ def launch(self) -> None: pipeline_run_id=pipeline_run.id, pipeline_run_metadata=pipeline_run_metadata, ) - - pipeline_model_version, pipeline_run = ( - step_run_utils.prepare_pipeline_run_model_version( - pipeline_run - ) - ) + if model_version := pipeline_run.model_version: + step_run_utils.log_model_version_dashboard_url( + model_version=model_version + ) request_factory = step_run_utils.StepRunRequestFactory( deployment=self._deployment, @@ -209,12 +207,10 @@ def launch(self) -> None: step_run = Client().zen_store.create_run_step( step_run_request ) - - step_model_version, step_run = ( - step_run_utils.prepare_step_run_model_version( - step_run=step_run, pipeline_run=pipeline_run + if model_version := step_run.model_version: + step_run_utils.log_model_version_dashboard_url( + model_version=model_version ) - ) if not step_run.status.is_finished: logger.info(f"Step `{self._step_name}` has started.") @@ -289,8 +285,8 @@ def _bypass() -> None: f"Using cached version of step `{self._step_name}`." ) if ( - model_version := step_model_version - or pipeline_model_version + model_version := step_run.model_version + or pipeline_run.model_version ): step_run_utils.link_output_artifacts_to_model_version( artifacts=step_run.outputs, diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index e371b4c509a..6451a4cc0a4 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -14,7 +14,7 @@ """Utilities for creating step runs.""" from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from zenml.client import Client from zenml.config.step_configurations import Step @@ -24,21 +24,13 @@ from zenml.model.utils import link_artifact_version_to_model_version from zenml.models import ( ArtifactVersionResponse, - ModelVersionPipelineRunRequest, ModelVersionResponse, PipelineDeploymentResponse, PipelineRunResponse, - PipelineRunUpdate, StepRunRequest, - StepRunResponse, - StepRunUpdate, ) from zenml.orchestrators import cache_utils, input_utils, utils from zenml.stack import Stack -from zenml.utils import pagination_utils, string_utils - -if TYPE_CHECKING: - from zenml.model.model import Model logger = get_logger(__name__) @@ -293,10 +285,6 @@ def create_cached_step_runs( deployment=deployment, pipeline_run=pipeline_run, stack=stack ) - pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version( - pipeline_run=pipeline_run - ) - while ( cache_candidates := find_cacheable_invocation_candidates( deployment=deployment, @@ -311,7 +299,9 @@ def create_cached_step_runs( # Make sure the request factory has the most up to date pipeline # run to avoid hydration calls - request_factory.pipeline_run = pipeline_run + request_factory.pipeline_run = Client().get_pipeline_run( + pipeline_run.id + ) try: step_run_request = request_factory.create_request( invocation_id @@ -336,15 +326,10 @@ def create_cached_step_runs( step_run = Client().zen_store.create_run_step(step_run_request) - # Refresh the pipeline run here to make sure we have the latest - # state - pipeline_run = Client().get_pipeline_run(pipeline_run.id) - - step_model_version, step_run = prepare_step_run_model_version( - step_run=step_run, pipeline_run=pipeline_run - ) - - if model_version := step_model_version or pipeline_model_version: + if ( + model_version := step_run.model_version + or pipeline_run.model_version + ): link_output_artifacts_to_model_version( artifacts=step_run.outputs, model_version=model_version, @@ -356,169 +341,6 @@ def create_cached_step_runs( return cached_invocations -def get_or_create_model_version_for_pipeline_run( - model: "Model", - pipeline_run: PipelineRunResponse, - substitutions: Dict[str, str], -) -> Tuple[ModelVersionResponse, bool]: - """Get or create a model version as part of a pipeline run. - - Args: - model: The model to get or create. - pipeline_run: The pipeline run for which the model should be created. - substitutions: Substitutions to apply to the model version name. - - Returns: - The model version and a boolean indicating whether it was newly created - or not. - """ - # Copy the model before modifying it so we don't accidently modify - # configurations in which the model object is potentially referenced - model = model.model_copy() - - if model.model_version_id: - return model._get_model_version(), False - elif model.version: - if isinstance(model.version, str): - model.version = string_utils.format_name_template( - model.version, - substitutions=substitutions, - ) - model.name = string_utils.format_name_template( - model.name, - substitutions=substitutions, - ) - - return ( - model._get_or_create_model_version(), - model._created_model_version, - ) - - # The model version should be created as part of this run - # -> We first check if it was already created as part of this run, and if - # not we do create it. If this is running in two parallel steps, we might - # run into issues that this will create two versions. Ideally, all model - # versions required for a pipeline run and its steps could be created - # server-side at run creation time before the first step starts. - if model_version := get_model_version_created_by_pipeline_run( - model_name=model.name, pipeline_run=pipeline_run - ): - return model_version, False - else: - return model._get_or_create_model_version(), True - - -def get_model_version_created_by_pipeline_run( - model_name: str, pipeline_run: PipelineRunResponse -) -> Optional[ModelVersionResponse]: - """Get a model version that was created by a specific pipeline run. - - This function does not refresh the pipeline run, so it will only try to - fetch the model version from existing steps if they're already part of the - response. - - Args: - model_name: The model name for which to get the version. - pipeline_run: The pipeline run for which to get the version. - - Returns: - A model version with the given name created by the run, or None if such - a model version does not exist. - """ - if pipeline_run.config.model and pipeline_run.model_version: - if ( - pipeline_run.config.model.name == model_name - and pipeline_run.config.model.version is None - ): - return pipeline_run.model_version - - # We fetch a list of hydrated step runs here in order to avoid hydration - # calls for each step separately. - candidate_step_runs = pagination_utils.depaginate( - Client().list_run_steps, - pipeline_run_id=pipeline_run.id, - model=model_name, - hydrate=True, - ) - for step_run in candidate_step_runs: - if step_run.config.model and step_run.model_version: - if ( - step_run.config.model.name == model_name - and step_run.config.model.version is None - ): - return step_run.model_version - - return None - - -def prepare_pipeline_run_model_version( - pipeline_run: PipelineRunResponse, -) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]: - """Prepare the model version for a pipeline run. - - Args: - pipeline_run: The pipeline run for which to prepare the model version. - - Returns: - The prepared model version and the updated pipeline run. - """ - model_version = None - - if pipeline_run.model_version: - model_version = pipeline_run.model_version - elif config_model := pipeline_run.config.model: - model_version, _ = get_or_create_model_version_for_pipeline_run( - model=config_model, - pipeline_run=pipeline_run, - substitutions=pipeline_run.config.substitutions, - ) - pipeline_run = Client().zen_store.update_run( - run_id=pipeline_run.id, - run_update=PipelineRunUpdate(model_version_id=model_version.id), - ) - link_pipeline_run_to_model_version( - pipeline_run=pipeline_run, model_version=model_version - ) - log_model_version_dashboard_url(model_version) - - return model_version, pipeline_run - - -def prepare_step_run_model_version( - step_run: StepRunResponse, pipeline_run: PipelineRunResponse -) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]: - """Prepare the model version for a step run. - - Args: - step_run: The step run for which to prepare the model version. - pipeline_run: The pipeline run of the step. - - Returns: - The prepared model version and the updated step run. - """ - model_version = None - - if step_run.model_version: - model_version = step_run.model_version - elif config_model := step_run.config.model: - model_version, created = get_or_create_model_version_for_pipeline_run( - model=config_model, - pipeline_run=pipeline_run, - substitutions=step_run.config.substitutions, - ) - step_run = Client().zen_store.update_run_step( - step_run_id=step_run.id, - step_run_update=StepRunUpdate(model_version_id=model_version.id), - ) - link_pipeline_run_to_model_version( - pipeline_run=pipeline_run, model_version=model_version - ) - if created: - log_model_version_dashboard_url(model_version) - - return model_version, step_run - - def log_model_version_dashboard_url( model_version: ModelVersionResponse, ) -> None: @@ -546,24 +368,6 @@ def log_model_version_dashboard_url( ) -def link_pipeline_run_to_model_version( - pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse -) -> None: - """Link a pipeline run to a model version. - - Args: - pipeline_run: The pipeline run to link. - model_version: The model version to link. - """ - client = Client() - client.zen_store.create_model_version_pipeline_run_link( - ModelVersionPipelineRunRequest( - pipeline_run=pipeline_run.id, - model_version=model_version.id, - ) - ) - - def link_output_artifacts_to_model_version( artifacts: Dict[str, List[ArtifactVersionResponse]], model_version: ModelVersionResponse, diff --git a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py new file mode 100644 index 00000000000..1d6082a9e75 --- /dev/null +++ b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py @@ -0,0 +1,173 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC SQL Zen Store implementation.""" + +from typing import ( + Optional, + Tuple, +) +from uuid import UUID + +from zenml.logger import get_logger +from zenml.models import ( + ModelRequest, + ModelResponse, + ModelVersionRequest, + ModelVersionResponse, +) +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + verify_permission, + verify_permission_for_model, +) +from zenml.zen_stores.sql_zen_store import SqlZenStore + +logger = get_logger(__name__) + + +class RBACSqlZenStore(SqlZenStore): + """Wrapper around the SQLZenStore that implements RBAC functionality.""" + + def _get_or_create_model( + self, model_request: ModelRequest + ) -> Tuple[bool, ModelResponse]: + """Get or create a model. + + Args: + model_request: The model request. + + # noqa: DAR401 + Raises: + Exception: If the user is not allowed to create a model. + + Returns: + A boolean whether the model was created or not, and the model. + """ + allow_model_creation = True + error = None + + try: + verify_permission( + resource_type=ResourceType.MODEL, action=Action.CREATE + ) + check_entitlement(resource_type=ResourceType.MODEL) + except Exception as e: + allow_model_creation = False + error = e + + if allow_model_creation: + created, model_response = super()._get_or_create_model( + model_request + ) + else: + try: + model_response = self.get_model(model_request.name) + created = False + except KeyError: + # The model does not exist. We now raise the error that + # explains why the model could not be created, instead of just + # the KeyError that it doesn't exist + assert error + raise error from None + + if created: + report_usage( + resource_type=ResourceType.MODEL, resource_id=model_response.id + ) + else: + verify_permission_for_model(model_response, action=Action.READ) + + return created, model_response + + def _get_model_version( + self, + model_id: UUID, + version_name: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> ModelVersionResponse: + """Get a model version. + + Args: + model_id: The ID of the model. + version_name: The name of the model version. + producer_run_id: The ID of the producer pipeline run. If this is + set, only numeric versions created as part of the pipeline run + will be returned. + + Returns: + The model version. + """ + model_version = super()._get_model_version( + model_id=model_id, + version_name=version_name, + producer_run_id=producer_run_id, + ) + verify_permission_for_model(model_version, action=Action.READ) + return model_version + + def _get_or_create_model_version( + self, + model_version_request: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, + ) -> Tuple[bool, ModelVersionResponse]: + """Get or create a model version. + + Args: + model_version_request: The model version request. + producer_run_id: ID of the producer pipeline run. + + # noqa: DAR401 + Raises: + Exception: If the authenticated user is not allowed to + create a model version. + + Returns: + A boolean whether the model version was created or not, and the + model version. + """ + allow_creation = True + error = None + + try: + verify_permission( + resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE + ) + except Exception as e: + allow_creation = False + error = e + + if allow_creation: + created, model_version_response = ( + super()._get_or_create_model_version(model_version_request, producer_run_id=producer_run_id) + ) + else: + try: + model_version_response = self._get_model_version( + model_id=model_version_request.model, + version_name=model_version_request.name, + producer_run_id=producer_run_id, + ) + created = False + except KeyError: + # The model version does not exist. We now raise the error that + # explains why the version could not be created, instead of just + # the KeyError that it doesn't exist + assert error + raise error from None + + return created, model_version_response diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 210f6b8b1ed..11467c44814 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -36,6 +36,7 @@ DEFAULT_STACK_AND_COMPONENT_NAME, DEFAULT_WORKSPACE_NAME, ENV_ZENML_DEFAULT_WORKSPACE_NAME, + ENV_ZENML_SERVER, IS_DEBUG_ENV, ) from zenml.enums import ( @@ -155,9 +156,16 @@ def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]: TypeError: If the store type is unsupported. """ if store_type == StoreType.SQL: - from zenml.zen_stores.sql_zen_store import SqlZenStore + if os.environ.get(ENV_ZENML_SERVER): + from zenml.zen_server.rbac.rbac_sql_zen_store import ( + RBACSqlZenStore, + ) + + return RBACSqlZenStore + else: + from zenml.zen_stores.sql_zen_store import SqlZenStore - return SqlZenStore + return SqlZenStore elif store_type == StoreType.REST: from zenml.zen_stores.rest_zen_store import RestZenStore diff --git a/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py b/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py new file mode 100644 index 00000000000..007b5ddbb8a --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py @@ -0,0 +1,68 @@ +"""Add model version producer run unique constraint [a1237ba94fd8]. + +Revision ID: a1237ba94fd8 +Revises: 26351d482b9e +Create Date: 2024-12-13 10:28:55.432414 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a1237ba94fd8" +down_revision = "26351d482b9e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "producer_run_id_if_numeric", + sqlmodel.sql.sqltypes.GUID(), + nullable=True, + ) + ) + + # Set the producer_run_id_if_numeric column to the model version ID for + # existing rows + connection = op.get_bind() + metadata = sa.MetaData() + metadata.reflect(only=("model_version",), bind=connection) + model_version_table = sa.Table("model_version", metadata) + + connection.execute( + model_version_table.update().values( + producer_run_id_if_numeric=model_version_table.c.id + ) + ) + + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.alter_column( + "producer_run_id_if_numeric", + existing_type=sqlmodel.sql.sqltypes.GUID(), + nullable=False, + ) + batch_op.create_unique_constraint( + "unique_numeric_version_for_pipeline_run", + ["model_id", "producer_run_id_if_numeric"], + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_numeric_version_for_pipeline_run", type_="unique" + ) + batch_op.drop_column("producer_run_id_if_numeric") + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index feb4a93dc80..41c186c75ca 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -15,10 +15,16 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast -from uuid import UUID +from uuid import UUID, uuid4 from pydantic import ConfigDict -from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint +from sqlalchemy import ( + BOOLEAN, + INTEGER, + TEXT, + Column, + UniqueConstraint, +) from sqlmodel import Field, Relationship from zenml.enums import ( @@ -228,11 +234,13 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): __tablename__ = MODEL_VERSION_TABLENAME __table_args__ = ( - # We need two unique constraints here: + # We need three unique constraints here: # - The first to ensure that each model version for a # model has a unique version number # - The second one to ensure that explicit names given by # users are unique + # - The third one to ensure that a pipeline run only produces a single + # auto-incremented version per model UniqueConstraint( "number", "model_id", @@ -243,6 +251,11 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): "model_id", name="unique_version_for_model_id", ), + UniqueConstraint( + "model_id", + "producer_run_id_if_numeric", + name="unique_numeric_version_for_pipeline_run", + ), ) workspace_id: UUID = build_foreign_key_field( @@ -312,12 +325,23 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): ), ) pipeline_runs: List["PipelineRunSchema"] = Relationship( - back_populates="model_version" + back_populates="model_version", ) step_runs: List["StepRunSchema"] = Relationship( back_populates="model_version" ) + # We want to make sure each pipeline run only creates a single numeric + # version for each model. To solve this, we need to add a unique constraint. + # If a value of a unique constraint is NULL it is ignored and the + # remaining values in the unique constraint have to be unique. In + # our case however, we only want the unique constraint applied in + # case there is a producer run and only for numeric versions. To solve this, + # we fall back to the model version ID (which is the primary key and + # therefore unique) in case there is no producer run or the version is not + # numeric. + producer_run_id_if_numeric: UUID + # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning. # It is possible to suppress this warning message with the following @@ -328,24 +352,36 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): @classmethod def from_request( - cls, model_version_request: ModelVersionRequest + cls, + model_version_request: ModelVersionRequest, + model_version_number: int, + producer_run_id: Optional[UUID] = None, ) -> "ModelVersionSchema": """Convert an `ModelVersionRequest` to an `ModelVersionSchema`. Args: model_version_request: The request model version to convert. + model_version_number: The model version number. + producer_run_id: The ID of the producer run. Returns: The converted schema. """ + id_ = uuid4() + is_numeric = str(model_version_number) == model_version_request.name + return cls( + id=id_, workspace_id=model_version_request.workspace, user_id=model_version_request.user, model_id=model_version_request.model, name=model_version_request.name, - number=model_version_request.number, + number=model_version_number, description=model_version_request.description, stage=model_version_request.stage, + producer_run_id_if_numeric=producer_run_id + if (producer_run_id and is_numeric) + else id_, ) def to_model( diff --git a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py index ae2fe609bcf..409bd2eebfc 100644 --- a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py @@ -228,13 +228,6 @@ def to_model( Returns: The created `PipelineDeploymentResponse`. """ - pipeline_configuration = PipelineConfiguration.model_validate_json( - self.pipeline_configuration - ) - step_configurations = json.loads(self.step_configurations) - for s, c in step_configurations.items(): - step_configurations[s] = Step.model_validate(c) - body = PipelineDeploymentResponseBody( user=self.user.to_model() if self.user else None, created=self.created, @@ -242,6 +235,13 @@ def to_model( ) metadata = None if include_metadata: + pipeline_configuration = PipelineConfiguration.model_validate_json( + self.pipeline_configuration + ) + step_configurations = json.loads(self.step_configurations) + for s, c in step_configurations.items(): + step_configurations[s] = Step.model_validate(c) + metadata = PipelineDeploymentResponseMetadata( workspace=self.workspace.to_model(), run_name_template=self.run_name_template, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index ce20d6687f6..19bdda8b28f 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -71,6 +71,7 @@ col, create_engine, delete, + desc, or_, select, ) @@ -296,7 +297,11 @@ replace_localhost_with_internal_hostname, ) from zenml.utils.pydantic_utils import before_validator_handler -from zenml.utils.string_utils import random_str, validate_name +from zenml.utils.string_utils import ( + format_name_template, + random_str, + validate_name, +) from zenml.zen_stores import template_utils from zenml.zen_stores.base_zen_store import ( BaseZenStore, @@ -5154,6 +5159,20 @@ def create_run( "already exists." ) + if model_version_id := self._get_or_create_model_version_for_run( + new_run + ): + new_run.model_version_id = model_version_id + session.add(new_run) + session.commit() + + self.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequest( + model_version=model_version_id, pipeline_run=new_run.id + ) + ) + session.refresh(new_run) + return new_run.to_model( include_metadata=True, include_resources=True ) @@ -8216,6 +8235,21 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: session.commit() session.refresh(step_schema) + if model_version_id := self._get_or_create_model_version_for_run( + step_schema + ): + step_schema.model_version_id = model_version_id + session.add(step_schema) + session.commit() + + self.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequest( + model_version=model_version_id, + pipeline_run=step_schema.pipeline_run_id, + ) + ) + session.refresh(step_schema) + return step_schema.to_model( include_metadata=True, include_resources=True ) @@ -10218,6 +10252,22 @@ def update_model( # ----------------------------- Model Versions ----------------------------- + def _get_or_create_model( + self, model_request: ModelRequest + ) -> Tuple[bool, ModelResponse]: + """Get or create a model. + + Args: + model_request: The model request. + + Returns: + A boolean whether the model was created or not, and the model. + """ + try: + return True, self.create_model(model_request) + except EntityExistsError: + return False, self.get_model(model_request.name) + def _get_next_numeric_version_for_model( self, session: Session, model_id: UUID ) -> int: @@ -10242,55 +10292,276 @@ def _get_next_numeric_version_for_model( else: return int(current_max_version) + 1 - def _model_version_exists(self, model_id: UUID, version: str) -> bool: + def _model_version_exists( + self, + model_id: UUID, + version: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> bool: """Check if a model version with a certain version exists. Args: model_id: The model ID of the version. version: The version name. + producer_run_id: The producer run ID. If given, checks if a numeric + version for the producer run exists. Returns: - If a model version with the given version name exists. + If a model version for the given arguments exists. """ + query = select(ModelVersionSchema.id).where( + ModelVersionSchema.model_id == model_id + ) + + if version: + query = query.where(ModelVersionSchema.name == version) + + if producer_run_id: + query = query.where( + ModelVersionSchema.producer_run_id_if_numeric + == producer_run_id, + ) + with Session(self.engine) as session: - return ( - session.exec( - select(ModelVersionSchema.id) - .where(ModelVersionSchema.model_id == model_id) - .where(ModelVersionSchema.name == version) - ).first() - is not None + return session.exec(query).first() is not None + + def _get_model_version( + self, + model_id: UUID, + version_name: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> ModelVersionResponse: + """Get a model version. + + Args: + model_id: The ID of the model. + version_name: The name of the model version. + producer_run_id: The ID of the producer pipeline run. If this is + set, only numeric versions created as part of the pipeline run + will be returned. + + Raises: + ValueError: If no version name or producer run ID was provided. + KeyError: If no model version was found. + + Returns: + The model version. + """ + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model_id + ) + + if version_name: + if version_name.isnumeric(): + query = query.where( + ModelVersionSchema.number == int(version_name) + ) + error_text = ( + f"No version with number {version_name} found " + f"for model {model_id}." + ) + elif version_name in ModelStages.values(): + if version_name == ModelStages.LATEST: + query = query.order_by( + desc(col(ModelVersionSchema.number)) + ).limit(1) + else: + query = query.where( + ModelVersionSchema.stage == version_name + ) + error_text = ( + f"No {version_name} stage version found for " + f"model {model_id}." + ) + else: + query = query.where(ModelVersionSchema.name == version_name) + error_text = ( + f"No {version_name} version found for model {model_id}." + ) + + elif producer_run_id: + query = query.where( + ModelVersionSchema.producer_run_id_if_numeric + == producer_run_id, + ) + error_text = ( + f"No numeric model version found for model {model_id} " + f"and producer run {producer_run_id}." + ) + else: + raise ValueError( + "Version name or producer run id need to be specified." ) - @track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION) - def create_model_version( - self, model_version: ModelVersionRequest + with Session(self.engine) as session: + schema = session.exec(query).one_or_none() + + if not schema: + raise KeyError(error_text) + + return schema.to_model( + include_metadata=True, include_resources=True + ) + + def _get_or_create_model_version( + self, + model_version_request: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, + ) -> Tuple[bool, ModelVersionResponse]: + """Get or create a model version. + + Args: + model_version_request: The model version request. + producer_run_id: ID of the producer pipeline run. + + Raises: + EntityCreationError: If the model version creation failed. + + Returns: + A boolean whether the model version was created or not, and the + model version. + """ + try: + model_version = self._create_model_version( + model_version=model_version_request, + producer_run_id=producer_run_id, + ) + track(event=AnalyticsEvent.CREATED_MODEL_VERSION) + return True, model_version + except EntityCreationError: + # Need to explicitly re-raise this here as otherwise the catching + # of the RuntimeError would include this + raise + except RuntimeError: + return False, self._get_model_version( + model_id=model_version_request.model, + producer_run_id=producer_run_id, + ) + except EntityExistsError: + return False, self._get_model_version( + model_id=model_version_request.model, + version_name=model_version_request.name, + ) + + def _get_or_create_model_version_for_run( + self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema] + ) -> Optional[UUID]: + """Get or create a model version for a pipeline or step run. + + Args: + pipeline_or_step_run: The pipeline or step run for which to create + the model version. + + Returns: + The model version. + """ + if isinstance(pipeline_or_step_run, PipelineRunSchema): + producer_run_id = pipeline_or_step_run.id + pipeline_run = pipeline_or_step_run.to_model(include_metadata=True) + configured_model = pipeline_run.config.model + substitutions = pipeline_run.config.substitutions + else: + producer_run_id = pipeline_or_step_run.pipeline_run_id + step_run = pipeline_or_step_run.to_model(include_metadata=True) + configured_model = step_run.config.model + substitutions = step_run.config.substitutions + + if not configured_model: + return None + + model_request = ModelRequest( + name=format_name_template( + configured_model.name, substitutions=substitutions + ), + license=configured_model.license, + description=configured_model.description, + audience=configured_model.audience, + use_cases=configured_model.use_cases, + limitations=configured_model.limitations, + trade_offs=configured_model.trade_offs, + ethics=configured_model.ethics, + save_models_to_registry=configured_model.save_models_to_registry, + user=pipeline_or_step_run.user_id, + workspace=pipeline_or_step_run.workspace_id, + ) + + _, model_response = self._get_or_create_model( + model_request=model_request + ) + + version_name = None + if configured_model.version is not None: + version_name = format_name_template( + str(configured_model.version), substitutions=substitutions + ) + + # If the model version was specified to be a numeric version or + # stage we don't try to create it (which will fail because it is not + # allowed) but try to fetch it immediately + if ( + version_name.isnumeric() + or version_name in ModelStages.values() + ): + return self._get_model_version( + model_id=model_response.id, version_name=version_name + ).id + + model_version_request = ModelVersionRequest( + model=model_response.id, + name=version_name, + description=configured_model.description, + tags=configured_model.tags, + user=pipeline_or_step_run.user_id, + workspace=pipeline_or_step_run.workspace_id, + ) + + _, model_version_response = self._get_or_create_model_version( + model_version_request=model_version_request, + producer_run_id=producer_run_id, + ) + return model_version_response.id + + def _create_model_version( + self, + model_version: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, ) -> ModelVersionResponse: """Creates a new model version. Args: model_version: the Model Version to be created. + producer_run_id: ID of the pipeline run that produced this model + version. Returns: The newly created model version. Raises: - ValueError: If `number` is not None during model version creation. + ValueError: If the requested version name is invalid. EntityExistsError: If a model version with the given name already exists. EntityCreationError: If the model version creation failed. + RuntimeError: If an auto-incremented model version already exists + for the producer run. """ - if model_version.number is not None: - raise ValueError( - "`number` field must be None during model version creation." - ) + has_custom_name = False + if model_version.name: + has_custom_name = True + validate_name(model_version) - model = self.get_model(model_version.model) + if model_version.name.isnumeric(): + raise ValueError( + "Can't create model version with custom numeric model " + "version name." + ) - has_custom_name = model_version.name is not None - if has_custom_name: - validate_name(model_version) + if str(model_version.name).lower() in ModelStages.values(): + raise ValueError( + "Can't create model version with a name that is used as a " + f"model version stage ({ModelStages.values()})." + ) + model = self.get_model(model_version.model) model_version_id = None remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION @@ -10298,17 +10569,19 @@ def create_model_version( remaining_tries -= 1 try: with Session(self.engine) as session: - model_version.number = ( + model_version_number = ( self._get_next_numeric_version_for_model( session=session, model_id=model.id, ) ) if not has_custom_name: - model_version.name = str(model_version.number) + model_version.name = str(model_version_number) model_version_schema = ModelVersionSchema.from_request( - model_version + model_version, + model_version_number=model_version_number, + producer_run_id=producer_run_id, ) session.add(model_version_schema) session.commit() @@ -10329,6 +10602,13 @@ def create_model_version( f"{model_version.name}): A model with the " "same name and version already exists." ) + elif producer_run_id and self._model_version_exists( + model_id=model.id, producer_run_id=producer_run_id + ): + raise RuntimeError( + "Auto-incremented model version already exists for " + f"producer run {producer_run_id}." + ) elif remaining_tries == 0: raise EntityCreationError( f"Failed to create version for model " @@ -10347,10 +10627,9 @@ def create_model_version( ) logger.debug( "Failed to create model version %s " - "(version %s) due to an integrity error. " + "due to an integrity error. " "Retrying in %f seconds.", model.name, - model_version.number, sleep_duration, ) time.sleep(sleep_duration) @@ -10365,6 +10644,20 @@ def create_model_version( return self.get_model_version(model_version_id) + @track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION) + def create_model_version( + self, model_version: ModelVersionRequest + ) -> ModelVersionResponse: + """Creates a new model version. + + Args: + model_version: the Model Version to be created. + + Returns: + The newly created model version. + """ + return self._create_model_version(model_version=model_version) + def get_model_version( self, model_version_id: UUID, hydrate: bool = True ) -> ModelVersionResponse: diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index d16b9dc31bd..9d91e1eb76e 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -268,13 +268,11 @@ def test_model_fetch_model_and_version_latest(self): def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" with ModelContext(create_model=False) as (mdl_name, _, _): - with mock.patch("zenml.model.model.logger.info") as logger: - mv = Model( - name=mdl_name, - version=ModelStages.PRODUCTION.value, - ) - logger.assert_called_once() - assert mv.version == ModelStages.PRODUCTION.value + mv = Model( + name=mdl_name, + version=ModelStages.PRODUCTION.value, + ) + assert mv.version == ModelStages.PRODUCTION.value mv = Model(name=mdl_name, version=ModelStages.PRODUCTION) assert mv.version == ModelStages.PRODUCTION diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index f070ca0272f..e3049d6cb59 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -93,7 +93,7 @@ def test_that_argument_as_get_artifact_of_model_in_pipeline_context_fails_if_not clean_client: "Client", ): producer_pipe(False) - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): consumer_pipe() diff --git a/tests/integration/functional/steps/test_model_version.py b/tests/integration/functional/steps/test_model_version.py index 2100890bd8a..3990494a7f2 100644 --- a/tests/integration/functional/steps/test_model_version.py +++ b/tests/integration/functional/steps/test_model_version.py @@ -22,7 +22,7 @@ from zenml import get_pipeline_context, get_step_context, pipeline, step from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client -from zenml.enums import ModelStages +from zenml.enums import ExecutionStatus, ModelStages from zenml.model.model import Model @@ -571,7 +571,7 @@ def _inner_pipeline(): # this will run all steps, including one requesting new version run_1 = f"run_{uuid4()}" # model is configured with latest stage, so a warm-up needed - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): _inner_pipeline.with_options(run_name=run_1)() run_2 = f"run_{uuid4()}" Model(name="step")._get_or_create_model_version() @@ -812,3 +812,147 @@ def _inner_pipeline(): assert "{time}" not in versions[1].version assert len(versions[1]._get_model_version().data_artifact_ids["data"]) == 2 assert versions[1].version != first_version_name + + +@step +def noop() -> None: + pass + + +def test_model_version_creation(clean_client: "Client"): + """Tests that model versions get created correctly for a pipeline run.""" + shared_model_name = random_resource_name() + custom_model_name = random_resource_name() + + @pipeline(model=Model(name=shared_model_name), enable_cache=False) + def _inner_pipeline(): + noop.with_options(model=Model(name=shared_model_name))(id="shared") + noop.with_options( + model=Model(name=shared_model_name, version="custom") + )(id="custom_version") + noop.with_options(model=Model(name=custom_model_name))( + id="custom_model" + ) + + run_1 = _inner_pipeline() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 2 + implicit_version = shared_versions[-2] + explicit_version = shared_versions[-1] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 1 + custom_version = custom_versions[-1] + + assert run_1.model_version_id == implicit_version.id + for name, step_ in run_1.steps.items(): + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_1.id + ) + assert len(links) == 3 + + run_2 = _inner_pipeline() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 3 + implicit_version = shared_versions[-1] + explicit_version = shared_versions[-2] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 2 + custom_version = custom_versions[-1] + + assert run_2.model_version_id == implicit_version.id + for name, step_ in run_2.steps.items(): + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_2.id + ) + assert len(links) == 3 + + # Run with caching enabled to see if everything still works + run_3 = _inner_pipeline.with_options(enable_cache=True)() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 4 + implicit_version = shared_versions[-1] + explicit_version = shared_versions[-3] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 3 + custom_version = custom_versions[-1] + + assert run_3.model_version_id == implicit_version.id + for name, step_ in run_3.steps.items(): + assert step_.status == ExecutionStatus.CACHED + + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_3.id + ) + assert len(links) == 3 + + +def test_model_version_fetching_by_stage(clean_client: "Client"): + """Tests that model versions can be fetched by number or stage.""" + model_name = random_resource_name() + + @pipeline(model=Model(name=model_name), enable_cache=False) + def _creator_pipeline(): + noop() + + @pipeline(model=Model(name=model_name, version=1), enable_cache=False) + def _fetch_by_version_number_pipeline(): + noop() + + @pipeline( + model=Model(name=model_name, version="latest"), enable_cache=False + ) + def _fetch_latest_version_pipeline(): + noop() + + @pipeline( + model=Model(name=model_name, version="production"), enable_cache=False + ) + def _fetch_prod_version_pipeline(): + noop() + + with pytest.raises(KeyError): + _fetch_by_version_number_pipeline() + + with pytest.raises(KeyError): + _fetch_latest_version_pipeline() + + with pytest.raises(KeyError): + _fetch_prod_version_pipeline() + + _creator_pipeline() + _creator_pipeline() + + versions = clean_client.list_model_versions(model_name) + assert len(versions) == 2 + mv_1, mv_2 = versions + mv_1.set_stage("production") + + run = _fetch_by_version_number_pipeline() + assert run.model_version_id == mv_1.id + + run = _fetch_latest_version_pipeline() + assert run.model_version_id == mv_2.id + + run = _fetch_prod_version_pipeline() + assert run.model_version_id == mv_1.id diff --git a/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py b/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py index 6ffde7bddac..8e5c41f0d32 100644 --- a/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py +++ b/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py @@ -141,9 +141,13 @@ def test_vertex_orchestrator_stack_validation( {"cpu_limit": "4", "gpu_limit": 4, "memory_limit": "1G"}, { "accelerator": { + "count": "1", + "type": "NVIDIA_TESLA_K80", "resourceCount": "1", "resourceType": "NVIDIA_TESLA_K80", }, + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -154,9 +158,13 @@ def test_vertex_orchestrator_stack_validation( {"cpu_limit": "1.0", "gpu_limit": 1, "memory_limit": "1G"}, { "accelerator": { + "count": "1", + "type": "NVIDIA_TESLA_K80", "resourceCount": "1", "resourceType": "NVIDIA_TESLA_K80", }, + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -166,6 +174,8 @@ def test_vertex_orchestrator_stack_validation( ResourceSettings(cpu_count=1, gpu_count=None, memory="1GB"), {"cpu_limit": None, "gpu_limit": None, "memory_limit": None}, { + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -174,7 +184,12 @@ def test_vertex_orchestrator_stack_validation( ( ResourceSettings(cpu_count=1, gpu_count=0, memory="1GB"), {"cpu_limit": None, "gpu_limit": None, "memory_limit": None}, - {"resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G"}, + { + "cpuLimit": 1.0, + "memoryLimit": 1.0, + "resourceCpuLimit": "1.0", + "resourceMemoryLimit": "1G", + }, ), ], ) @@ -233,13 +248,16 @@ def _build_kfp_pipeline() -> None: job_spec = pipeline_json["deploymentSpec"]["executors"][ f"exec-{step_name}" ]["container"] + if "accelerator" in job_spec["resources"]: - if "count" in job_spec["resources"]["accelerator"]: - expected_resources["accelerator"]["count"] = expected_resources[ - "accelerator" - ]["resourceCount"] - if "type" in job_spec["resources"]["accelerator"]: - expected_resources["accelerator"]["type"] = expected_resources[ - "accelerator" - ]["resourceType"] + if "resourceCount" not in job_spec["resources"]["accelerator"]: + expected_resources["accelerator"].pop("resourceCount", None) + if "resourceType" not in job_spec["resources"]["accelerator"]: + expected_resources["accelerator"].pop("resourceType", None) + + if "resourceCpuLimit" not in job_spec["resources"]: + expected_resources.pop("resourceCpuLimit", None) + if "resourceMemoryLimit" not in job_spec["resources"]: + expected_resources.pop("resourceMemoryLimit", None) + assert job_spec["resources"] == expected_resources diff --git a/tests/unit/model/test_model_version_init.py b/tests/unit/model/test_model_version_init.py deleted file mode 100644 index 21009b96d5c..00000000000 --- a/tests/unit/model/test_model_version_init.py +++ /dev/null @@ -1,27 +0,0 @@ -from unittest.mock import patch - -import pytest - -from zenml.model.model import Model - - -@pytest.mark.parametrize( - "version_name,logger", - [ - ["staging", "info"], - ["1", "info"], - [1, "info"], - ], - ids=[ - "Pick model by text stage", - "Pick model by text version number", - "Pick model by integer version number", - ], -) -def test_init_warns(version_name, logger): - with patch(f"zenml.model.model.logger.{logger}") as logger: - Model( - name="foo", - version=version_name, - ) - logger.assert_called_once()