Skip to content

Commit

Permalink
Some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 9, 2024
1 parent c66a24d commit 08b00ca
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
4 changes: 0 additions & 4 deletions src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ class ModelVersionRequest(WorkspaceScopedRequest):
default=None,
)

number: Optional[int] = Field(
description="The number of the model version",
default=None,
)
model: UUID = Field(
description="The ID of the model containing version",
)
Expand Down
10 changes: 8 additions & 2 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,17 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):

@classmethod
def from_request(
cls, model_version_request: ModelVersionRequest
cls,
model_version_request: ModelVersionRequest,
model_version_number: int,
producer_run_id: Optional[UUID] = None,
) -> "ModelVersionSchema":
"""Convert an `ModelVersionRequest` to an `ModelVersionSchema`.
Args:
model_version_request: The request model version to convert.
model_version_number: The model version number.
producer_run_id: The ID of the producer run.
Returns:
The converted schema.
Expand All @@ -385,7 +390,8 @@ def from_request(
user_id=model_version_request.user,
model_id=model_version_request.model,
name=model_version_request.name,
number=model_version_request.number,
number=model_version_number,
producer_run_id=producer_run_id,
description=model_version_request.description,
stage=model_version_request.stage,
)
Expand Down
19 changes: 7 additions & 12 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10527,11 +10527,11 @@ def _create_model_version(
The newly created model version.
Raises:
ValueError: If `number` is not None during model version creation.
ValueError: If the requested version name is invalid.
EntityExistsError: If a model version with the given name already
exists.
EntityCreationError: If the model version creation failed.
RuntimeError: If a auto-incremented model version already exists
RuntimeError: If an auto-incremented model version already exists
for the producer run.
"""
if model_version.number is not None:
Expand Down Expand Up @@ -10564,19 +10564,20 @@ def _create_model_version(
remaining_tries -= 1
try:
with Session(self.engine) as session:
model_version.number = (
model_version_number = (
self._get_next_numeric_version_for_model(
session=session,
model_id=model.id,
)
)
if not has_custom_name:
model_version.name = str(model_version.number)
model_version.name = str(model_version_number)

model_version_schema = ModelVersionSchema.from_request(
model_version
model_version,
model_version_number=model_version_number,
producer_run_id=producer_run_id,
)
model_version_schema.producer_run_id = producer_run_id
session.add(model_version_schema)
session.commit()

Expand Down Expand Up @@ -10650,12 +10651,6 @@ def create_model_version(
Returns:
The newly created model version.
Raises:
ValueError: If `number` is not None during model version creation.
EntityExistsError: If a model version with the given name already
exists.
EntityCreationError: If the model version creation failed.
"""
return self._create_model_version(model_version=model_version)

Expand Down

0 comments on commit 08b00ca

Please sign in to comment.