From 276232710eca2ae57d9e0649dcb167beecd94c6b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 10 Dec 2024 10:09:57 +0100 Subject: [PATCH] Refactor --- ...c3_add_model_version_producer_run_index.py | 4 +- src/zenml/zen_stores/schemas/model_schemas.py | 11 ++-- src/zenml/zen_stores/sql_zen_store.py | 55 +++++++++---------- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/cd7b17d5a3c3_add_model_version_producer_run_index.py b/src/zenml/zen_stores/migrations/versions/cd7b17d5a3c3_add_model_version_producer_run_index.py index c810a3ac95..c97392fe22 100644 --- a/src/zenml/zen_stores/migrations/versions/cd7b17d5a3c3_add_model_version_producer_run_index.py +++ b/src/zenml/zen_stores/migrations/versions/cd7b17d5a3c3_add_model_version_producer_run_index.py @@ -38,7 +38,7 @@ def upgrade() -> None: ) batch_op.add_column( sa.Column( - "producer_run_id_with_fallback", + "producer_run_id_if_numeric", sa.CHAR(32), sa.Computed( "CASE WHEN producer_run_id IS NOT NULL AND is_numeric = TRUE THEN producer_run_id ELSE id END", @@ -48,7 +48,7 @@ def upgrade() -> None: ) batch_op.create_index( "unique_numeric_version_for_pipeline_run", - ["model_id", "is_numeric", "producer_run_id_with_fallback"], + ["model_id", "producer_run_id_if_numeric"], unique=True, ) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 4621de44fa..318ea4b74f 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -257,8 +257,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): Index( "unique_numeric_version_for_pipeline_run", "model_id", - "is_numeric", - "producer_run_id_with_fallback", + "producer_run_id_if_numeric", unique=True, ), ) @@ -348,10 +347,10 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): # If a value of a unique constraint is NULL it is ignored and the # remaining values in the unique constraint have to be unique. In # our case however, we only want the unique constraint applied in - # case there is a producer run. To solve this, we fallback to the - # model version ID (which is the primary key and therefore unique) - # in case there is no producer run. - producer_run_id_with_fallback: str = Field( + # case there is a producer run and only for numeric versions. To solve this, + # we fall back to the model version ID (which is the primary key and + # therefore unique) in case there is no producer run. + producer_run_id_if_numeric: str = Field( sa_column=Column( sa.CHAR(32), Computed( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index d83d419352..a8ad7c01f2 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -10462,40 +10462,37 @@ def _get_or_create_model_version_for_run( except EntityExistsError: model_response = self.get_model(model_request.name) - # If the model version was specified to be a numeric version or stage - # we don't create it but try to fetch it instead - if isinstance(configured_model.version, int) or ( - isinstance(configured_model.version, str) - and configured_model.version.isnumeric() - ): - model_version_id = self._get_model_version_id( - model_id=model_response.id, - version_number=int(configured_model.version), - ) - if not model_version_id: - raise KeyError( - f"No version with number {configured_model.version} found " - f"for model {model_response.name}." - ) - return model_version_id - elif configured_model.version in ModelStages.values(): - model_version_id = self._get_model_version_id( - model_id=model_response.id, - version_stage=ModelStages(configured_model.version), - ) - if not model_version_id: - raise KeyError( - f"No {configured_model.version} stage version found for " - f"model {model_response.name}." - ) - return model_version_id - version_name = None - if isinstance(configured_model.version, str): + if configured_model.version is not None: version_name = format_name_template( str(configured_model.version), substitutions=substitutions ) + # If the model version was specified to be a numeric version or + # stage we don't create it but try to fetch it instead + if version_name.isnumeric(): + model_version_id = self._get_model_version_id( + model_id=model_response.id, + version_number=int(version_name), + ) + if not model_version_id: + raise KeyError( + f"No version with number {version_name} found " + f"for model {model_response.name}." + ) + return model_version_id + elif version_name in ModelStages.values(): + model_version_id = self._get_model_version_id( + model_id=model_response.id, + version_stage=ModelStages(version_name), + ) + if not model_version_id: + raise KeyError( + f"No {version_name} stage version found for " + f"model {model_response.name}." + ) + return model_version_id + model_version_request = ModelVersionRequest( model=model_response.id, name=version_name,