Skip to content

Commit

Permalink
[Feature] Support custom artifact_location in MLflowVisBackend (#1505)
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo authored Feb 26, 2024
1 parent c423d0c commit 2fe0ece
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
11 changes: 9 additions & 2 deletions mmengine/visualization/vis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ class MLflowVisBackend(BaseVisBackend):
will be added to the experiment. If it is None, which means all
the config will be added. Defaults to None.
`New in version 0.7.4.`
artifact_location (str, optional): The location to store run artifacts.
If None, the server picks an appropriate default.
Defaults to None.
`New in version 0.10.4.`
"""

def __init__(self,
Expand All @@ -680,7 +684,8 @@ def __init__(self,
tracking_uri: Optional[str] = None,
artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py',
'yaml'),
tracked_config_keys: Optional[dict] = None):
tracked_config_keys: Optional[dict] = None,
artifact_location: Optional[str] = None):
super().__init__(save_dir)
self._exp_name = exp_name
self._run_name = run_name
Expand All @@ -689,6 +694,7 @@ def __init__(self,
self._tracking_uri = tracking_uri
self._artifact_suffix = artifact_suffix
self._tracked_config_keys = tracked_config_keys
self._artifact_location = artifact_location

def _init_env(self):
"""Setup env for MLflow."""
Expand Down Expand Up @@ -726,7 +732,8 @@ def _init_env(self):
self._exp_name = self._exp_name or 'Default'

if self._mlflow.get_experiment_by_name(self._exp_name) is None:
self._mlflow.create_experiment(self._exp_name)
self._mlflow.create_experiment(
self._exp_name, artifact_location=self._artifact_location)

self._mlflow.set_experiment(self._exp_name)

Expand Down
8 changes: 8 additions & 0 deletions tests/test_visualizer/test_vis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ def test_experiment(self):
mlflow_vis_backend = MLflowVisBackend('temp_dir')
assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow

def test_create_experiment(self):
with patch('mlflow.create_experiment') as mock_create_experiment:
MLflowVisBackend(
'temp_dir', exp_name='test',
artifact_location='foo')._init_env()
mock_create_experiment.assert_any_call(
'test', artifact_location='foo')

def test_add_config(self):
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
mlflow_vis_backend = MLflowVisBackend('temp_dir')
Expand Down

0 comments on commit 2fe0ece

Please sign in to comment.