Skip to content

Commit

Permalink
fix: load_multiple
Browse files Browse the repository at this point in the history
Signed-off-by: Leila Wang <[email protected]>
  • Loading branch information
yleilawang committed Sep 23, 2024
1 parent e8b190a commit 8db67b1
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
39 changes: 20 additions & 19 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def load_multiple(
self,
skeys: KEYS,
dkeys_list: list[list[str]],
) -> Optional[dict[str, ArtifactData]]:
) -> Optional[ArtifactData]:
"""
Load multiple artifacts from the registry for pyfunc models.
Args:
Expand All @@ -203,24 +203,23 @@ def load_multiple(
Returns
-------
Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys
to the loaded artifacts, or None if no artifacts were found.
Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None.
ArtifactData should contain a dictionary of artifacts.
"""
dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
if loaded_model is not None:
metadata = loaded_model.artifact.unwrap_python_model().metadata
dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts
artifacts_dict = {}
for artifact in dict_artifacts.values():
artifact_data = ArtifactData(
artifact=artifact.artifact, metadata=metadata, extras=None
)
dynamic_key = ":".join(artifact.dkeys)
artifacts_dict[dynamic_key] = artifact_data
else:
artifacts_dict = None
return artifacts_dict
try:
unwrapped_composite_model = loaded_model.artifact.unwrap_python_model()
except Exception:
_LOGGER.exception("Error occurred while unwrapping python model")
return None

dict_artifacts = unwrapped_composite_model.dict_artifacts
metadata = loaded_model.metadata
version_info = loaded_model.extras
return ArtifactData(artifact=dict_artifacts, metadata=metadata, extras=version_info)
return None

@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
Expand Down Expand Up @@ -297,14 +296,14 @@ def save_multiple(
mlflow ModelVersion instance
"""
multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
dkeys_list = multiple_artifacts.get_dkeys_list()
dkeys_list = multiple_artifacts._get_dkeys_list()
sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
return self.save(
skeys=multiple_artifacts.skeys,
skeys=skeys,
dkeys=sorted_dkeys,
artifact=multiple_artifacts,
artifact_type="pyfunc",
metadata=multiple_artifacts.metadata,
**metadata,
)

@staticmethod
Expand Down Expand Up @@ -449,12 +448,14 @@ class CompositeModels(mlflow.pyfunc.PythonModel):
metadata (META_VT): Additional metadata associated with the artifacts.
"""

__slots__ = ("skeys", "dict_artifacts", "metadata")

def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT):
self.skeys = skeys
self.dict_artifacts = dict_artifacts
self.metadata = metadata

def get_dkeys_list(self):
def _get_dkeys_list(self):
"""
Returns a list of all dynamic keys in the stored artifacts.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.13.2"
version = "0.13.3"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
15 changes: 15 additions & 0 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ def test_cache_loading(self):
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))

@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc)
def test_cache_loading_pyfunc(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
ml.load_multiple(skeys=self.skeys, dkeys_list=dkeys_list)
unique_sorted_dkeys = ["AE", "infer", "scaler"]
key = MLflowRegistry.construct_key(self.skeys, unique_sorted_dkeys)
self.assertIsNotNone(ml._load_from_cache(key))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8db67b1

Please sign in to comment.