Skip to content

Commit

Permalink
Create model versions server-side to avoid race conditions (#3254)
Browse files Browse the repository at this point in the history
* POC

* Log model version if created

* Fetch model version if it already exists

* Test

* Missing docstring

* Refactor fetching model version into separate method

* Remove FK cycle

* Add DB migration

* Cleanup

* Move to computed property to solve alembic issue

* mypy

* Some cleanup

* More tests and fixes

* Typo

* Move logs so they don't happen when creating the object

* Always log model version

* Mysql fix

* Linting

* More linting

* Make columns non-nullable

* Formatting

* Fix some tests

* Refactor

* Fix comparison for mysql

* Implement RBAC checks inside SQL ZenStore

* Docstrings

* Try to fix vertex test

* Fix alembic order

* Ignore generated columns in DB migration

* Linting

* Rework entire DB stuff to work with mariadb

* Revert change to DB backup

* Add missing null check

* Try to fix vertex again
  • Loading branch information
schustmi authored Dec 13, 2024
1 parent 941dc81 commit d6fae4e
Show file tree
Hide file tree
Showing 15 changed files with 828 additions and 325 deletions.
28 changes: 12 additions & 16 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,22 +509,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]:
raise ValueError(
"`model_version_id` field is for internal use only"
)

version = data.get("version", None)

if (
version in [stage.value for stage in ModelStages]
and not suppress_class_validation_warnings
):
logger.info(
f"Version `{version}` matches one of the possible "
"`ModelStages` and will be fetched using stage."
)
if str(version).isnumeric() and not suppress_class_validation_warnings:
logger.info(
f"`version` `{version}` is numeric and will be fetched "
"using version number."
)
data["suppress_class_validation_warnings"] = True
return data

Expand Down Expand Up @@ -603,6 +587,18 @@ def _get_model_version(
hydrate=hydrate,
)
else:
if self.version in ModelStages.values():
logger.info(
f"Version `{self.version}` for model {self.name} matches "
"one of the possible `ModelStages` and will be fetched "
"using stage."
)
if str(self.version).isnumeric():
logger.info(
f"Version `{self.version}` for model {self.name} is "
"numeric and will be fetched using version number."
)

mv = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
Expand Down
4 changes: 0 additions & 4 deletions src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ class ModelVersionRequest(WorkspaceScopedRequest):
default=None,
)

number: Optional[int] = Field(
description="The number of the model version",
default=None,
)
model: UUID = Field(
description="The ID of the model containing version",
)
Expand Down
22 changes: 9 additions & 13 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,10 @@ def launch(self) -> None:
pipeline_run_id=pipeline_run.id,
pipeline_run_metadata=pipeline_run_metadata,
)

pipeline_model_version, pipeline_run = (
step_run_utils.prepare_pipeline_run_model_version(
pipeline_run
)
)
if model_version := pipeline_run.model_version:
step_run_utils.log_model_version_dashboard_url(
model_version=model_version
)

request_factory = step_run_utils.StepRunRequestFactory(
deployment=self._deployment,
Expand All @@ -209,12 +207,10 @@ def launch(self) -> None:
step_run = Client().zen_store.create_run_step(
step_run_request
)

step_model_version, step_run = (
step_run_utils.prepare_step_run_model_version(
step_run=step_run, pipeline_run=pipeline_run
if model_version := step_run.model_version:
step_run_utils.log_model_version_dashboard_url(
model_version=model_version
)
)

if not step_run.status.is_finished:
logger.info(f"Step `{self._step_name}` has started.")
Expand Down Expand Up @@ -289,8 +285,8 @@ def _bypass() -> None:
f"Using cached version of step `{self._step_name}`."
)
if (
model_version := step_model_version
or pipeline_model_version
model_version := step_run.model_version
or pipeline_run.model_version
):
step_run_utils.link_output_artifacts_to_model_version(
artifacts=step_run.outputs,
Expand Down
212 changes: 8 additions & 204 deletions src/zenml/orchestrators/step_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Utilities for creating step runs."""

from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

from zenml.client import Client
from zenml.config.step_configurations import Step
Expand All @@ -24,21 +24,13 @@
from zenml.model.utils import link_artifact_version_to_model_version
from zenml.models import (
ArtifactVersionResponse,
ModelVersionPipelineRunRequest,
ModelVersionResponse,
PipelineDeploymentResponse,
PipelineRunResponse,
PipelineRunUpdate,
StepRunRequest,
StepRunResponse,
StepRunUpdate,
)
from zenml.orchestrators import cache_utils, input_utils, utils
from zenml.stack import Stack
from zenml.utils import pagination_utils, string_utils

if TYPE_CHECKING:
from zenml.model.model import Model

logger = get_logger(__name__)

Expand Down Expand Up @@ -293,10 +285,6 @@ def create_cached_step_runs(
deployment=deployment, pipeline_run=pipeline_run, stack=stack
)

pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version(
pipeline_run=pipeline_run
)

while (
cache_candidates := find_cacheable_invocation_candidates(
deployment=deployment,
Expand All @@ -311,7 +299,9 @@ def create_cached_step_runs(

# Make sure the request factory has the most up to date pipeline
# run to avoid hydration calls
request_factory.pipeline_run = pipeline_run
request_factory.pipeline_run = Client().get_pipeline_run(
pipeline_run.id
)
try:
step_run_request = request_factory.create_request(
invocation_id
Expand All @@ -336,15 +326,10 @@ def create_cached_step_runs(

step_run = Client().zen_store.create_run_step(step_run_request)

# Refresh the pipeline run here to make sure we have the latest
# state
pipeline_run = Client().get_pipeline_run(pipeline_run.id)

step_model_version, step_run = prepare_step_run_model_version(
step_run=step_run, pipeline_run=pipeline_run
)

if model_version := step_model_version or pipeline_model_version:
if (
model_version := step_run.model_version
or pipeline_run.model_version
):
link_output_artifacts_to_model_version(
artifacts=step_run.outputs,
model_version=model_version,
Expand All @@ -356,169 +341,6 @@ def create_cached_step_runs(
return cached_invocations


def get_or_create_model_version_for_pipeline_run(
model: "Model",
pipeline_run: PipelineRunResponse,
substitutions: Dict[str, str],
) -> Tuple[ModelVersionResponse, bool]:
"""Get or create a model version as part of a pipeline run.
Args:
model: The model to get or create.
pipeline_run: The pipeline run for which the model should be created.
substitutions: Substitutions to apply to the model version name.
Returns:
The model version and a boolean indicating whether it was newly created
or not.
"""
# Copy the model before modifying it so we don't accidently modify
# configurations in which the model object is potentially referenced
model = model.model_copy()

if model.model_version_id:
return model._get_model_version(), False
elif model.version:
if isinstance(model.version, str):
model.version = string_utils.format_name_template(
model.version,
substitutions=substitutions,
)
model.name = string_utils.format_name_template(
model.name,
substitutions=substitutions,
)

return (
model._get_or_create_model_version(),
model._created_model_version,
)

# The model version should be created as part of this run
# -> We first check if it was already created as part of this run, and if
# not we do create it. If this is running in two parallel steps, we might
# run into issues that this will create two versions. Ideally, all model
# versions required for a pipeline run and its steps could be created
# server-side at run creation time before the first step starts.
if model_version := get_model_version_created_by_pipeline_run(
model_name=model.name, pipeline_run=pipeline_run
):
return model_version, False
else:
return model._get_or_create_model_version(), True


def get_model_version_created_by_pipeline_run(
model_name: str, pipeline_run: PipelineRunResponse
) -> Optional[ModelVersionResponse]:
"""Get a model version that was created by a specific pipeline run.
This function does not refresh the pipeline run, so it will only try to
fetch the model version from existing steps if they're already part of the
response.
Args:
model_name: The model name for which to get the version.
pipeline_run: The pipeline run for which to get the version.
Returns:
A model version with the given name created by the run, or None if such
a model version does not exist.
"""
if pipeline_run.config.model and pipeline_run.model_version:
if (
pipeline_run.config.model.name == model_name
and pipeline_run.config.model.version is None
):
return pipeline_run.model_version

# We fetch a list of hydrated step runs here in order to avoid hydration
# calls for each step separately.
candidate_step_runs = pagination_utils.depaginate(
Client().list_run_steps,
pipeline_run_id=pipeline_run.id,
model=model_name,
hydrate=True,
)
for step_run in candidate_step_runs:
if step_run.config.model and step_run.model_version:
if (
step_run.config.model.name == model_name
and step_run.config.model.version is None
):
return step_run.model_version

return None


def prepare_pipeline_run_model_version(
pipeline_run: PipelineRunResponse,
) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]:
"""Prepare the model version for a pipeline run.
Args:
pipeline_run: The pipeline run for which to prepare the model version.
Returns:
The prepared model version and the updated pipeline run.
"""
model_version = None

if pipeline_run.model_version:
model_version = pipeline_run.model_version
elif config_model := pipeline_run.config.model:
model_version, _ = get_or_create_model_version_for_pipeline_run(
model=config_model,
pipeline_run=pipeline_run,
substitutions=pipeline_run.config.substitutions,
)
pipeline_run = Client().zen_store.update_run(
run_id=pipeline_run.id,
run_update=PipelineRunUpdate(model_version_id=model_version.id),
)
link_pipeline_run_to_model_version(
pipeline_run=pipeline_run, model_version=model_version
)
log_model_version_dashboard_url(model_version)

return model_version, pipeline_run


def prepare_step_run_model_version(
step_run: StepRunResponse, pipeline_run: PipelineRunResponse
) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]:
"""Prepare the model version for a step run.
Args:
step_run: The step run for which to prepare the model version.
pipeline_run: The pipeline run of the step.
Returns:
The prepared model version and the updated step run.
"""
model_version = None

if step_run.model_version:
model_version = step_run.model_version
elif config_model := step_run.config.model:
model_version, created = get_or_create_model_version_for_pipeline_run(
model=config_model,
pipeline_run=pipeline_run,
substitutions=step_run.config.substitutions,
)
step_run = Client().zen_store.update_run_step(
step_run_id=step_run.id,
step_run_update=StepRunUpdate(model_version_id=model_version.id),
)
link_pipeline_run_to_model_version(
pipeline_run=pipeline_run, model_version=model_version
)
if created:
log_model_version_dashboard_url(model_version)

return model_version, step_run


def log_model_version_dashboard_url(
model_version: ModelVersionResponse,
) -> None:
Expand Down Expand Up @@ -546,24 +368,6 @@ def log_model_version_dashboard_url(
)


def link_pipeline_run_to_model_version(
pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse
) -> None:
"""Link a pipeline run to a model version.
Args:
pipeline_run: The pipeline run to link.
model_version: The model version to link.
"""
client = Client()
client.zen_store.create_model_version_pipeline_run_link(
ModelVersionPipelineRunRequest(
pipeline_run=pipeline_run.id,
model_version=model_version.id,
)
)


def link_output_artifacts_to_model_version(
artifacts: Dict[str, List[ArtifactVersionResponse]],
model_version: ModelVersionResponse,
Expand Down
Loading

0 comments on commit d6fae4e

Please sign in to comment.