diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index d5fa1fc068..f74eab1fcd 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -669,6 +669,10 @@ class MLflowVisBackend(BaseVisBackend): will be added to the experiment. If it is None, which means all the config will be added. Defaults to None. `New in version 0.7.4.` + artifact_location (str, optional): The location to store run artifacts. + If None, the server picks an appropriate default. + Defaults to None. + `New in version 0.10.4.` """ def __init__(self, @@ -680,7 +684,8 @@ def __init__(self, tracking_uri: Optional[str] = None, artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), - tracked_config_keys: Optional[dict] = None): + tracked_config_keys: Optional[dict] = None, + artifact_location: Optional[str] = None): super().__init__(save_dir) self._exp_name = exp_name self._run_name = run_name @@ -689,6 +694,7 @@ def __init__(self, self._tracking_uri = tracking_uri self._artifact_suffix = artifact_suffix self._tracked_config_keys = tracked_config_keys + self._artifact_location = artifact_location def _init_env(self): """Setup env for MLflow.""" @@ -726,7 +732,8 @@ def _init_env(self): self._exp_name = self._exp_name or 'Default' if self._mlflow.get_experiment_by_name(self._exp_name) is None: - self._mlflow.create_experiment(self._exp_name) + self._mlflow.create_experiment( + self._exp_name, artifact_location=self._artifact_location) self._mlflow.set_experiment(self._exp_name) diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index 59d87480c7..c991462ef9 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -282,6 +282,14 @@ def test_experiment(self): mlflow_vis_backend = MLflowVisBackend('temp_dir') assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow + def test_create_experiment(self): + with patch('mlflow.create_experiment') as mock_create_experiment: + MLflowVisBackend( + 'temp_dir', exp_name='test', + artifact_location='foo')._init_env() + mock_create_experiment.assert_any_call( + 'test', artifact_location='foo') + def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) mlflow_vis_backend = MLflowVisBackend('temp_dir')