Skip to content

Commit

Permalink
More tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 9, 2024
1 parent 08b00ca commit a77bc53
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10393,9 +10393,14 @@ def _get_model_version_id(
query = query.where(ModelVersionSchema.number == version_number)

if version_stage:
query = query.where(
ModelVersionSchema.stage == version_stage.value
)
if version_stage == ModelStages.LATEST:
query = query.order_by(desc(ModelVersionSchema.number)).limit(
1
)
else:
query = query.where(
ModelVersionSchema.stage == version_stage.value
)

if producer_run_id:
query = query.where(
Expand Down Expand Up @@ -10469,16 +10474,18 @@ def _get_or_create_model_version_for_run(
f"No version with number {configured_model.version} found "
f"for model {model_response.name}."
)
elif isinstance(configured_model.version, ModelStages):
return model_version_id
elif configured_model.version in ModelStages.values():
model_version_id = self._get_model_version_id(
model_id=model_response.id,
version_stage=configured_model.version,
version_stage=ModelStages(configured_model.version),
)
if not model_version_id:
raise KeyError(
f"No {configured_model.version} stage version found for "
f"model {model_response.name}."
)
return model_version_id

version_name = None
if isinstance(configured_model.version, str):
Expand Down Expand Up @@ -10534,11 +10541,6 @@ def _create_model_version(
RuntimeError: If an auto-incremented model version already exists
for the producer run.
"""
if model_version.number is not None:
raise ValueError(
"The model version number can't be explicitly specified."
)

has_custom_name = False
if model_version.name:
has_custom_name = True
Expand Down
51 changes: 51 additions & 0 deletions tests/integration/functional/steps/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,54 @@ def _inner_pipeline():
pipeline_run_id=run_3.id
)
assert len(links) == 3


def test_model_version_fetching_by_stage(clean_client: "Client"):
"""Tests that model versions get created correctly for a pipeline run."""
model_name = random_resource_name()

@pipeline(model=Model(name=model_name), enable_cache=False)
def _creator_pipeline():
noop()

@pipeline(model=Model(name=model_name, version=1), enable_cache=False)
def _fetch_by_version_number_pipeline():
noop()

@pipeline(
model=Model(name=model_name, version="latest"), enable_cache=False
)
def _fetch_latest_version_pipeline():
noop()

@pipeline(
model=Model(name=model_name, version="production"), enable_cache=False
)
def _fetch_prod_version_pipeline():
noop()

with pytest.raises(KeyError):
_fetch_by_version_number_pipeline()

with pytest.raises(KeyError):
_fetch_latest_version_pipeline()

with pytest.raises(KeyError):
_fetch_prod_version_pipeline()

_creator_pipeline()
_creator_pipeline()

versions = clean_client.list_model_versions(model_name)
assert len(versions) == 2
mv_1, mv_2 = versions
mv_1.set_stage("production")

run = _fetch_by_version_number_pipeline()
assert run.model_version_id == mv_1.id

run = _fetch_latest_version_pipeline()
assert run.model_version_id == mv_2.id

run = _fetch_prod_version_pipeline()
assert run.model_version_id == mv_1.id

0 comments on commit a77bc53

Please sign in to comment.