From 5a43c33e7375554463ed30911efd6a17078f801e Mon Sep 17 00:00:00 2001 From: yleilawang Date: Fri, 13 Sep 2024 13:42:19 -0400 Subject: [PATCH] fix: test_mlflow_registry --- numalogic/registry/mlflow_registry.py | 1 + tests/registry/test_mlflow_registry.py | 90 +++++++++++++++++--------- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index 8db09624..43c5b519 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -17,6 +17,7 @@ import mlflow.pyfunc import mlflow.pytorch +import mlflow.sklearn from mlflow.entities.model_registry import ModelVersion from mlflow.exceptions import RestException from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 2058a4dc..d728f6bd 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import unittest from contextlib import contextmanager from unittest.mock import patch, Mock @@ -9,10 +10,12 @@ from mlflow.store.entities import PagedList from sklearn.preprocessing import StandardScaler +from mlflow.models.model import ModelInfo from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache + + from numalogic.registry.mlflow_registry import ModelStage -from numalogic.tools.exceptions import ModelVersionError from tests.registry._mlflow_utils import ( model_sklearn, create_model, @@ -26,7 +29,7 @@ mock_list_of_model_version, mock_list_of_model_version2, return_sklearn_rundata, - mock_get_model_version_obj, + mock_get_model_version_obj ) TRACKING_URI = "http://0.0.0.0:5009" @@ -53,6 +56,7 @@ def test_construct_key(self): key = MLflowRegistry.construct_key(skeys, dkeys) self.assertEqual("model_:nnet::error1", key) + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @@ -70,27 +74,35 @@ def test_save_model(self): mock_status = "READY" self.assertEqual(mock_status, status.status) + @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) + @patch("mlflow.log_param", mock_log_state_dict) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata()))) @patch("mlflow.active_run", Mock(return_value=return_sklearn_rundata())) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) - @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version) def test_save_model_sklearn(self): model = self.model_sklearn ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn") + status = ml.save(skeys=skeys, + dkeys=dkeys, + artifact=model, + artifact_type="sklearn") + mock_status = "READY" self.assertEqual(mock_status, status.status) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", {"lr": 0.01}) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) def test_load_model_when_pytorch_model_exist1(self): @@ -98,16 +110,23 @@ def test_load_model_when_pytorch_model_exist1(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch") + ml.save(skeys=skeys, + dkeys=dkeys, + artifact=model, + **{"lr": 0.01}, + artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertIsNotNone(data.metadata) self.assertIsInstance(data.artifact, VanillaAE) + + - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) def test_load_model_when_pytorch_model_exist2(self): @@ -115,7 +134,10 @@ def test_load_model_when_pytorch_model_exist2(self): ml = MLflowRegistry(TRACKING_URI, models_to_retain=2) skeys = self.skeys dkeys = self.dkeys - ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch") + ml.save(skeys=skeys, + dkeys=dkeys, + artifact=model, + artifact_type="pytorch") data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") self.assertEqual(data.metadata, {}) self.assertIsInstance(data.artifact, VanillaAE) @@ -147,12 +169,15 @@ def test_load_model_when_sklearn_model_exist(self): self.assertIsInstance(data.artifact, StandardScaler) self.assertEqual(data.metadata, {}) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata()))) @patch("mlflow.active_run", Mock(return_value=return_empty_rundata())) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) def test_load_model_with_version(self): @@ -165,6 +190,8 @@ def test_load_model_with_version(self): self.assertIsInstance(data.artifact, VanillaAE) self.assertEqual(data.metadata, {}) + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch( @@ -177,12 +204,17 @@ def test_staging_model_load_error(self): ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE) skeys = self.skeys dkeys = self.dkeys - ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") - self.assertRaises(ModelVersionError) + with self.assertLogs(level="ERROR") as log: + result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch") + self.assertIsNone(result) # Ensure the result is None + self.assertTrue(any("No Model found" in message for message in log.output)) # Check that the expected log was made + + + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) - @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version()) + @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) def test_both_version_latest_model_with_version(self): @@ -191,6 +223,7 @@ def test_both_version_latest_model_with_version(self): dkeys = self.dkeys with self.assertRaises(ValueError): ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch") + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @@ -247,6 +280,8 @@ def test_no_implementation(self): ml.load(skeys=fake_skeys, dkeys=fake_dkeys) self.assertTrue(log.output) + + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @@ -254,7 +289,7 @@ def test_no_implementation(self): @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) - @patch("mlflow.tracking.MlflowClient.delete_model_version", None) + @patch("mlflow.tracking.MlflowClient.delete_model_version", Mock(return_value=None)) @patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError)) def test_delete_model_when_model_exist(self): model = self.model @@ -321,12 +356,15 @@ def test_load_other_mlflow_err(self): dkeys = self.dkeys self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", {"lr": 0.01}) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) def test_is_model_stale_true(self): @@ -342,12 +380,14 @@ def test_is_model_stale_true(self): data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") self.assertTrue(ml.is_artifact_stale(data, 12)) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + + @patch("mlflow.pytorch.log_model", mock_log_model_pytorch) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", {"lr": 0.01}) + @patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)]))) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) + @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) def test_is_model_stale_false(self): @@ -381,10 +421,10 @@ def test_cache(self): self.assertIsNotNone(registry._load_from_cache("key")) self.assertIsNotNone(registry._clear_cache("key")) - @patch("mlflow.pytorch.log_model", mock_log_model_pytorch()) + + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) - @patch("mlflow.log_params", {"lr": 0.01}) @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @@ -392,18 +432,10 @@ def test_cache(self): def test_cache_loading(self): cache_registry = LocalLRUCache(ttl=50000) ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry) - ml.save( - skeys=self.skeys, - dkeys=self.dkeys, - artifact=self.model, - **{"lr": 0.01}, - artifact_type="pytorch", - ) ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") key = MLflowRegistry.construct_key(self.skeys, self.dkeys) self.assertIsNotNone(ml._load_from_cache(key)) - data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch") - self.assertIsNotNone(data) + if __name__ == "__main__":