Skip to content

Commit

Permalink
fix: test_mlflow_registry
Browse files Browse the repository at this point in the history
  • Loading branch information
yleilawang authored and yleilawang committed Sep 18, 2024
1 parent 630322f commit 5a43c33
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
1 change: 1 addition & 0 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
Expand Down
90 changes: 61 additions & 29 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
import unittest
from contextlib import contextmanager
from unittest.mock import patch, Mock
Expand All @@ -9,10 +10,12 @@
from mlflow.store.entities import PagedList
from sklearn.preprocessing import StandardScaler

from mlflow.models.model import ModelInfo

Check failure on line 13 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

tests/registry/test_mlflow_registry.py:13:33: F401 `mlflow.models.model.ModelInfo` imported but unused
from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.exceptions import ModelVersionError
from tests.registry._mlflow_utils import (
model_sklearn,
create_model,
Expand All @@ -26,7 +29,7 @@
mock_list_of_model_version,
mock_list_of_model_version2,
return_sklearn_rundata,
mock_get_model_version_obj,
mock_get_model_version_obj
)

TRACKING_URI = "http://0.0.0.0:5009"
Expand All @@ -53,6 +56,7 @@ def test_construct_key(self):
key = MLflowRegistry.construct_key(skeys, dkeys)
self.assertEqual("model_:nnet::error1", key)


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
Expand All @@ -70,52 +74,70 @@ def test_save_model(self):
mock_status = "READY"
self.assertEqual(mock_status, status.status)


@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_sklearn_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
def test_save_model_sklearn(self):
model = self.model_sklearn
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="sklearn")
status = ml.save(skeys=skeys,

Check failure on line 90 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:90:38: W291 Trailing whitespace
dkeys=dkeys,

Check failure on line 91 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:91:38: W291 Trailing whitespace
artifact=model,

Check failure on line 92 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:92:41: W291 Trailing whitespace
artifact_type="sklearn")

Check failure on line 94 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

tests/registry/test_mlflow_registry.py:94:1: W293 Blank line contains whitespace
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_load_model_when_pytorch_model_exist1(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, **{"lr": 0.01}, artifact_type="pytorch")
ml.save(skeys=skeys,

Check failure on line 113 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:113:29: W291 Trailing whitespace
dkeys=dkeys,

Check failure on line 114 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:114:29: W291 Trailing whitespace
artifact=model,

Check failure on line 115 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:115:32: W291 Trailing whitespace
**{"lr": 0.01},

Check failure on line 116 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W291)

tests/registry/test_mlflow_registry.py:116:32: W291 Trailing whitespace
artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNotNone(data.metadata)
self.assertIsInstance(data.artifact, VanillaAE)

Check failure on line 121 in tests/registry/test_mlflow_registry.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

tests/registry/test_mlflow_registry.py:121:1: W293 Blank line contains whitespace


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_when_pytorch_model_exist2(self):
model = self.model
ml = MLflowRegistry(TRACKING_URI, models_to_retain=2)
skeys = self.skeys
dkeys = self.dkeys
ml.save(skeys=skeys, dkeys=dkeys, artifact=model, artifact_type="pytorch")
ml.save(skeys=skeys,
dkeys=dkeys,
artifact=model,
artifact_type="pytorch")
data = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertEqual(data.metadata, {})
self.assertIsInstance(data.artifact, VanillaAE)
Expand Down Expand Up @@ -147,12 +169,15 @@ def test_load_model_when_sklearn_model_exist(self):
self.assertIsInstance(data.artifact, StandardScaler)
self.assertEqual(data.metadata, {})

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_empty_rundata()))
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_load_model_with_version(self):
Expand All @@ -165,6 +190,8 @@ def test_load_model_with_version(self):
self.assertIsInstance(data.artifact, VanillaAE)
self.assertEqual(data.metadata, {})



@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch(
Expand All @@ -177,12 +204,17 @@ def test_staging_model_load_error(self):
ml = MLflowRegistry(TRACKING_URI, model_stage=ModelStage.STAGE)
skeys = self.skeys
dkeys = self.dkeys
ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertRaises(ModelVersionError)
with self.assertLogs(level="ERROR") as log:
result = ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch")
self.assertIsNone(result) # Ensure the result is None
self.assertTrue(any("No Model found" in message for message in log.output)) # Check that the expected log was made




@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version())
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata()))
def test_both_version_latest_model_with_version(self):
Expand All @@ -191,6 +223,7 @@ def test_both_version_latest_model_with_version(self):
dkeys = self.dkeys
with self.assertRaises(ValueError):
ml.load(skeys=skeys, dkeys=dkeys, latest=False, artifact_type="pytorch")


@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
Expand Down Expand Up @@ -247,14 +280,16 @@ def test_no_implementation(self):
ml.load(skeys=fake_skeys, dkeys=fake_dkeys)
self.assertTrue(log.output)



@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.tracking.MlflowClient.delete_model_version", None)
@patch("mlflow.tracking.MlflowClient.delete_model_version", Mock(return_value=None))
@patch("mlflow.pytorch.load_model", Mock(side_effect=RuntimeError))
def test_delete_model_when_model_exist(self):
model = self.model
Expand Down Expand Up @@ -321,12 +356,15 @@ def test_load_other_mlflow_err(self):
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys, artifact_type="pytorch"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())


@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_true(self):
Expand All @@ -342,12 +380,14 @@ def test_is_model_stale_true(self):
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertTrue(ml.is_artifact_stale(data, 12))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.log_params", Mock(return_value=OrderedDict([("learning_rate",0.01)])))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_is_model_stale_false(self):
Expand Down Expand Up @@ -381,29 +421,21 @@ def test_cache(self):
self.assertIsNotNone(registry._load_from_cache("key"))
self.assertIsNotNone(registry._clear_cache("key"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())


@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_cache_loading(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
ml.save(
skeys=self.skeys,
dkeys=self.dkeys,
artifact=self.model,
**{"lr": 0.01},
artifact_type="pytorch",
)
ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertIsNotNone(data)



if __name__ == "__main__":
Expand Down

0 comments on commit 5a43c33

Please sign in to comment.