From 2d992168580b8fae862d4c46ccdf195616098477 Mon Sep 17 00:00:00 2001 From: daavoo Date: Fri, 23 Feb 2024 11:05:46 +0100 Subject: [PATCH] [Feature] Support custom `artifact_location` in MLflowVisBackend. https://mlflow.org/docs/latest/python_api/mlflow.html?highlight=create_experiment#mlflow.create_experiment --- mmengine/visualization/vis_backend.py | 7 +++++-- tests/test_visualizer/test_vis_backend.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index d5fa1fc068..9f5d37dde8 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -680,7 +680,8 @@ def __init__(self, tracking_uri: Optional[str] = None, artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), - tracked_config_keys: Optional[dict] = None): + tracked_config_keys: Optional[dict] = None, + artifact_location: Optional[str] = None): super().__init__(save_dir) self._exp_name = exp_name self._run_name = run_name @@ -689,6 +690,7 @@ def __init__(self, self._tracking_uri = tracking_uri self._artifact_suffix = artifact_suffix self._tracked_config_keys = tracked_config_keys + self._artifact_location = artifact_location def _init_env(self): """Setup env for MLflow.""" @@ -726,7 +728,8 @@ def _init_env(self): self._exp_name = self._exp_name or 'Default' if self._mlflow.get_experiment_by_name(self._exp_name) is None: - self._mlflow.create_experiment(self._exp_name) + self._mlflow.create_experiment( + self._exp_name, artifact_location=self._artifact_location) self._mlflow.set_experiment(self._exp_name) diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index 59d87480c7..c991462ef9 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -282,6 +282,14 @@ def test_experiment(self): mlflow_vis_backend = MLflowVisBackend('temp_dir') assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow + def test_create_experiment(self): + with patch('mlflow.create_experiment') as mock_create_experiment: + MLflowVisBackend( + 'temp_dir', exp_name='test', + artifact_location='foo')._init_env() + mock_create_experiment.assert_any_call( + 'test', artifact_location='foo') + def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) mlflow_vis_backend = MLflowVisBackend('temp_dir')