diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 62948928..75c4bd7a 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -193,7 +193,7 @@ def load_multiple( self, skeys: KEYS, dkeys_list: list[list[str]], - ) -> Optional[dict[str, ArtifactData]]: + ) -> Optional[ArtifactData]: """ Load multiple artifacts from the registry for pyfunc models. Args: @@ -203,24 +203,23 @@ def load_multiple( Returns ------- - Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys - to the loaded artifacts, or None if no artifacts were found. + Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None. + ArtifactData should contain a dictionary of artifacts. """ dkeys = self.__get_sorted_unique_dkeys(dkeys_list) loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc") if loaded_model is not None: - metadata = loaded_model.artifact.unwrap_python_model().metadata - dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts - artifacts_dict = {} - for artifact in dict_artifacts.values(): - artifact_data = ArtifactData( - artifact=artifact.artifact, metadata=metadata, extras=None - ) - dynamic_key = ":".join(artifact.dkeys) - artifacts_dict[dynamic_key] = artifact_data - else: - artifacts_dict = None - return artifacts_dict + try: + unwrapped_composite_model = loaded_model.artifact.unwrap_python_model() + except Exception: + _LOGGER.exception("Error occurred while unwrapping python model") + return None + + dict_artifacts = unwrapped_composite_model.dict_artifacts + metadata = loaded_model.metadata + version_info = loaded_model.extras + return ArtifactData(artifact=dict_artifacts, metadata=metadata, extras=version_info) + return None @staticmethod def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None: @@ -297,14 +296,14 @@ def save_multiple( mlflow ModelVersion instance """ multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata) - dkeys_list = multiple_artifacts.get_dkeys_list() + dkeys_list = multiple_artifacts._get_dkeys_list() sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list) return self.save( - skeys=multiple_artifacts.skeys, + skeys=skeys, dkeys=sorted_dkeys, artifact=multiple_artifacts, artifact_type="pyfunc", - metadata=multiple_artifacts.metadata, + **metadata, ) @staticmethod @@ -449,12 +448,14 @@ class CompositeModels(mlflow.pyfunc.PythonModel): metadata (META_VT): Additional metadata associated with the artifacts. """ + __slots__ = ("skeys", "dict_artifacts", "metadata") + def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT): self.skeys = skeys self.dict_artifacts = dict_artifacts self.metadata = metadata - def get_dkeys_list(self): + def _get_dkeys_list(self): """ Returns a list of all dynamic keys in the stored artifacts. diff --git a/pyproject.toml b/pyproject.toml index 222b544e..5fb0804b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.13.2" +version = "0.13.3" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index de2cbd44..f755e0a1 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -457,6 +457,21 @@ def test_cache_loading(self): key = MLflowRegistry.construct_key(self.skeys, self.dkeys) self.assertIsNotNone(ml._load_from_cache(key)) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata()))) + @patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata())) + @patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc) + def test_cache_loading_pyfunc(self): + cache_registry = LocalLRUCache(ttl=50000) + ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) + dkeys_list = [["AE", "infer"], ["scaler", "infer"]] + ml.load_multiple(skeys=self.skeys, dkeys_list=dkeys_list) + unique_sorted_dkeys = ["AE", "infer", "scaler"] + key = MLflowRegistry.construct_key(self.skeys, unique_sorted_dkeys) + self.assertIsNotNone(ml._load_from_cache(key)) + if __name__ == "__main__": unittest.main()