diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 5fb0deb89c..cae03a05c3 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -10360,47 +10360,51 @@ def _model_version_exists(self, model_id: UUID, version: str) -> bool: is not None ) - def _get_model_version_id_for_producer_run( - self, model_id: UUID, producer_run_id: UUID + def _get_model_version_id( + self, + model_id: UUID, + version_name: Optional[str] = None, + version_number: Optional[int] = None, + version_stage: Optional[ModelStages] = None, + producer_run_id: Optional[UUID] = None, ) -> Optional[UUID]: - """Get the ID of the numeric model version created for a pipeline run. + """Get the model version ID. Args: - model_id: The model ID. - producer_run_id: The ID of the producer pipeline run. + model_id: The ID of the model. + version_name: The name of the model version. + version_number: The number of the model version. + version_stage: The stage of the model version. + producer_run_id: The ID of the producer pipeline run. If this is + set, only numeric versions created as part of the pipeline run + will be returned. Returns: - ID of the numeric model version that was created for the pipeline - run, or None if no numeric model version was created yet. + The ID of the model version. 
""" - with Session(self.engine) as session: - query = select(ModelVersionSchema.id).where( - ModelVersionSchema.producer_run_id == producer_run_id, - ModelVersionSchema.is_numeric.is_(True), - ModelVersionSchema.model_id == model_id, - ) + query = select(ModelVersionSchema.id).where( + ModelVersionSchema.model_id == model_id + ) - return session.exec(query).one_or_none() + if version_name: + query = query.where(ModelVersionSchema.name == version_name) - def _get_model_version_id_for_name( - self, model_id: UUID, version_name: str - ) -> UUID: - """Get the model version ID by name + if version_number: + query = query.where(ModelVersionSchema.number == version_number) - Args: - model_id: The ID of the model. - version_name: The name of the model version. + if version_stage: + query = query.where( + ModelVersionSchema.stage == version_stage.value + ) - Returns: - The ID of the model version. - """ - with Session(self.engine) as session: - query = select(ModelVersionSchema.id).where( - ModelVersionSchema.model_id == model_id, - ModelVersionSchema.name == version_name, + if producer_run_id: + query = query.where( + ModelVersionSchema.producer_run_id == producer_run_id, + ModelVersionSchema.is_numeric.is_(True), ) - return session.exec(query).one() + with Session(self.engine) as session: + return session.exec(query).one_or_none() def _get_or_create_model_version_for_run( self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema] @@ -10411,6 +10415,10 @@ def _get_or_create_model_version_for_run( pipeline_or_step_run: The pipeline or step run for which to create the model version. + Raises: + KeyError: If the configured model for the run specifies a version + number or stage for which no model version exists. + Returns: The model version. 
""" @@ -10449,8 +10457,31 @@ def _get_or_create_model_version_for_run( except EntityExistsError: model_response = self.get_model(model_request.name) + # If the model version was specified to be a numeric version or stage + # we don't create it but try to fetch it instead + if isinstance(configured_model.version, int): + model_version_id = self._get_model_version_id( + model_id=model_response.id, + version_number=configured_model.version, + ) + if not model_version_id: + raise KeyError( + f"No version with number {configured_model.version} found " + f"for model {model_response.name}." + ) + elif isinstance(configured_model.version, ModelStages): + model_version_id = self._get_model_version_id( + model_id=model_response.id, + version_stage=configured_model.version, + ) + if not model_version_id: + raise KeyError( + f"No {configured_model.version} stage version found for " + f"model {model_response.name}." + ) + version_name = None - if configured_model.version is not None: + if isinstance(configured_model.version, str): version_name = format_name_template( str(configured_model.version), substitutions=substitutions ) @@ -10465,17 +10496,22 @@ def _get_or_create_model_version_for_run( ) try: - return self._create_model_version_or_get_from_producer_run( + model_version_id = self._create_model_version( model_version=model_version_request, producer_run_id=producer_run_id, ).id + track(event=AnalyticsEvent.CREATED_MODEL_VERSION) + return model_version_id + except RuntimeError: + return self._get_model_version_id( + model_id=model_response.id, producer_run_id=producer_run_id + ) except EntityExistsError: - return self._get_model_version_id_for_name( + return self._get_model_version_id( model_id=model_response.id, version_name=version_name ) - # TODO: figure out analytics - def _create_model_version_or_get_from_producer_run( + def _create_model_version( self, model_version: ModelVersionRequest, producer_run_id: Optional[UUID] = None, @@ -10484,6 +10520,8 @@ def 
_create_model_version_or_get_from_producer_run( Args: model_version: the Model Version to be created. + producer_run_id: ID of the pipeline run that produced this model + version. Returns: The newly created model version. @@ -10493,23 +10531,32 @@ def _create_model_version_or_get_from_producer_run( EntityExistsError: If a model version with the given name already exists. EntityCreationError: If the model version creation failed. + RuntimeError: If an auto-incremented model version already exists + for the producer run. """ if model_version.number is not None: raise ValueError( "The model version number can't be explicitly specified." ) - if model_version.name and model_version.name.isnumeric(): - raise ValueError( - "Custom numeric model version names are now allowed." - ) + has_custom_name = False + if model_version.name: + has_custom_name = True + validate_name(model_version) - model = self.get_model(model_version.model) + if model_version.name.isnumeric(): + raise ValueError( + "Can't create model version with custom numeric model " + "version name." + ) - has_custom_name = model_version.name is not None - if has_custom_name: - validate_name(model_version) + if str(model_version.name).lower() in ModelStages.values(): + raise ValueError( + "Can't create model version with a name that is used as a " + f"model version stage ({ModelStages.values()})." + ) + model = self.get_model(model_version.model) model_version_id = None remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION @@ -10549,15 +10596,16 @@ def _create_model_version_or_get_from_producer_run( f"{model_version.name}): A model with the " "same name and version already exists." 
) - elif producer_run_id and ( - model_version_id - := self._get_model_version_id_for_producer_run( + elif ( + producer_run_id + and self._get_model_version_id( model_id=model.id, producer_run_id=producer_run_id ) ): - # There is a numeric model version for the same model and - # producer run already -> We return that one - break + raise RuntimeError( + "Auto-incremented model version already exists for " + f"producer run {producer_run_id}." + ) elif remaining_tries == 0: raise EntityCreationError( f"Failed to create version for model " @@ -10612,90 +10660,7 @@ def create_model_version( exists. EntityCreationError: If the model version creation failed. """ - if model_version.number is not None: - raise ValueError( - "`number` field must be None during model version creation." - ) - - model = self.get_model(model_version.model) - - has_custom_name = model_version.name is not None - if has_custom_name: - validate_name(model_version) - - model_version_id = None - - remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - while remaining_tries > 0: - remaining_tries -= 1 - try: - with Session(self.engine) as session: - model_version.number = ( - self._get_next_numeric_version_for_model( - session=session, - model_id=model.id, - ) - ) - if not has_custom_name: - model_version.name = str(model_version.number) - - model_version_schema = ModelVersionSchema.from_request( - model_version - ) - session.add(model_version_schema) - session.commit() - - model_version_id = model_version_schema.id - break - except IntegrityError: - if has_custom_name and self._model_version_exists( - model_id=model.id, version=cast(str, model_version.name) - ): - # We failed not because of a version number conflict, - # but because the user requested a version name that - # is already taken -> We don't retry anymore but fail - # immediately. 
- raise EntityExistsError( - f"Unable to create model version " - f"{model.name} (version " - f"{model_version.name}): A model with the " - "same name and version already exists." - ) - elif remaining_tries == 0: - raise EntityCreationError( - f"Failed to create version for model " - f"{model.name}. This is most likely " - "caused by multiple parallel requests that try " - "to create versions for this model in the " - "database." - ) - else: - attempt = ( - MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - - remaining_tries - ) - sleep_duration = exponential_backoff_with_jitter( - attempt=attempt - ) - logger.debug( - "Failed to create model version %s " - "(version %s) due to an integrity error. " - "Retrying in %f seconds.", - model.name, - model_version.number, - sleep_duration, - ) - time.sleep(sleep_duration) - - assert model_version_id - if model_version.tags: - self._attach_tags_to_resource( - tag_names=model_version.tags, - resource_id=model_version_id, - resource_type=TaggableResourceTypes.MODEL_VERSION, - ) - - return self.get_model_version(model_version_id) + return self._create_model_version(model_version=model_version) def get_model_version( self, model_version_id: UUID, hydrate: bool = True