From 82ea731ac25e876fbeef8123aa7cc63eaf6fae1c Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Mon, 23 Sep 2024 16:51:01 -0400 Subject: [PATCH] fix: test cases Signed-off-by: Leila Wang --- tests/registry/_mlflow_utils.py | 4 +++- tests/registry/test_mlflow_registry.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/registry/_mlflow_utils.py b/tests/registry/_mlflow_utils.py index 1120c3b7..1ffc0830 100644 --- a/tests/registry/_mlflow_utils.py +++ b/tests/registry/_mlflow_utils.py @@ -385,7 +385,9 @@ def return_pyfunc_rundata(): status="RUNNING", user_id="lol", ), - run_data=RunData(metrics={}, tags={}, params={}), + run_data=RunData( + metrics={}, tags={}, params=[mlflow.entities.Param("learning_rate", "0.01")] + ), ) diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index f755e0a1..611d803a 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -116,11 +116,11 @@ def test_load_multiple_models_when_pyfunc_model_exist(self): skeys = self.skeys dkeys_list = [["AE", "infer"], ["scaler", "infer"]] data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list) - self.assertIsNotNone(data["AE:infer"].metadata) - self.assertIsNotNone(data["scaler:infer"].metadata) - self.assertIsInstance(data, dict) - self.assertIsInstance(data["AE:infer"].artifact, VanillaAE) - self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler) + self.assertIsNotNone(data.metadata) + self.assertIsInstance(data, ArtifactData) + self.assertIsInstance(data.artifact, dict) + self.assertIsInstance(data.artifact["AE"].artifact, VanillaAE) + self.assertIsInstance(data.artifact["scaler"].artifact, StandardScaler) @patch("mlflow.sklearn.log_model", mock_log_model_sklearn) @patch("mlflow.log_param", mock_log_state_dict)