Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 10, 2024
1 parent 92e76e9 commit 2762327
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def upgrade() -> None:
)
batch_op.add_column(
sa.Column(
"producer_run_id_with_fallback",
"producer_run_id_if_numeric",
sa.CHAR(32),
sa.Computed(
"CASE WHEN producer_run_id IS NOT NULL AND is_numeric = TRUE THEN producer_run_id ELSE id END",
Expand All @@ -48,7 +48,7 @@ def upgrade() -> None:
)
batch_op.create_index(
"unique_numeric_version_for_pipeline_run",
["model_id", "is_numeric", "producer_run_id_with_fallback"],
["model_id", "producer_run_id_if_numeric"],
unique=True,
)

Expand Down
11 changes: 5 additions & 6 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
Index(
"unique_numeric_version_for_pipeline_run",
"model_id",
"is_numeric",
"producer_run_id_with_fallback",
"producer_run_id_if_numeric",
unique=True,
),
)
Expand Down Expand Up @@ -348,10 +347,10 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
# If a value of a unique constraint is NULL it is ignored and the
# remaining values in the unique constraint have to be unique. In
# our case however, we only want the unique constraint applied in
# case there is a producer run. To solve this, we fallback to the
# model version ID (which is the primary key and therefore unique)
# in case there is no producer run.
producer_run_id_with_fallback: str = Field(
# case there is a producer run and only for numeric versions. To solve this,
# we fall back to the model version ID (which is the primary key and
# therefore unique) in case there is no producer run.
producer_run_id_if_numeric: str = Field(
sa_column=Column(
sa.CHAR(32),
Computed(
Expand Down
55 changes: 26 additions & 29 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10462,40 +10462,37 @@ def _get_or_create_model_version_for_run(
except EntityExistsError:
model_response = self.get_model(model_request.name)

# If the model version was specified to be a numeric version or stage
# we don't create it but try to fetch it instead
if isinstance(configured_model.version, int) or (
isinstance(configured_model.version, str)
and configured_model.version.isnumeric()
):
model_version_id = self._get_model_version_id(
model_id=model_response.id,
version_number=int(configured_model.version),
)
if not model_version_id:
raise KeyError(
f"No version with number {configured_model.version} found "
f"for model {model_response.name}."
)
return model_version_id
elif configured_model.version in ModelStages.values():
model_version_id = self._get_model_version_id(
model_id=model_response.id,
version_stage=ModelStages(configured_model.version),
)
if not model_version_id:
raise KeyError(
f"No {configured_model.version} stage version found for "
f"model {model_response.name}."
)
return model_version_id

version_name = None
if isinstance(configured_model.version, str):
if configured_model.version is not None:
version_name = format_name_template(
str(configured_model.version), substitutions=substitutions
)

# If the model version was specified to be a numeric version or
# stage we don't create it but try to fetch it instead
if version_name.isnumeric():
model_version_id = self._get_model_version_id(
model_id=model_response.id,
version_number=int(version_name),
)
if not model_version_id:
raise KeyError(
f"No version with number {version_name} found "
f"for model {model_response.name}."
)
return model_version_id
elif version_name in ModelStages.values():
model_version_id = self._get_model_version_id(
model_id=model_response.id,
version_stage=ModelStages(version_name),
)
if not model_version_id:
raise KeyError(
f"No {version_name} stage version found for "
f"model {model_response.name}."
)
return model_version_id

model_version_request = ModelVersionRequest(
model=model_response.id,
name=version_name,
Expand Down

0 comments on commit 2762327

Please sign in to comment.