diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py index 17886874b1..828d4429f2 100644 --- a/tests/callbacks/test_loggers_across_callbacks.py +++ b/tests/callbacks/test_loggers_across_callbacks.py @@ -17,6 +17,15 @@ ) +@pytest.fixture(autouse=True) +def setup_mlflow_tracking(monkeypatch, tmp_path): + mlflow = pytest.importorskip('mlflow') + # Use a temporary directory instead of 'databricks' + tracking_uri = str(tmp_path / 'mlruns') + monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) + os.makedirs(tracking_uri, exist_ok=True) + + @pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True)) @pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True)) @pytest.mark.filterwarnings('ignore::UserWarning') diff --git a/tests/conftest.py b/tests/conftest.py index 0e0937c476..cd670e2151 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -150,11 +150,3 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus: int): if exitstatus == 5: session.exitstatus = 0 # Ignore no-test-ran errors - -@pytest.fixture(autouse=True) -def setup_mlflow_tracking(monkeypatch, tmp_path): - mlflow = pytest.importorskip('mlflow') - # Use a temporary directory instead of 'databricks' - tracking_uri = str(tmp_path / 'mlruns') - monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) - os.makedirs(tracking_uri, exist_ok=True) diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 9d84baa06f..65a7f53ca7 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -25,6 +25,15 @@ ) +@pytest.fixture(autouse=True) +def setup_mlflow_tracking(monkeypatch, tmp_path): + mlflow = pytest.importorskip('mlflow') + # Use a temporary directory instead of 'databricks' + tracking_uri = str(tmp_path / 'mlruns') + monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) + os.makedirs(tracking_uri, exist_ok=True) + + def _get_latest_mlflow_run(experiment_name, tracking_uri=None): pytest.importorskip('mlflow') from mlflow import MlflowClient