From 5a2031599995c27a3799047299c4065398f314a4 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Tue, 19 Nov 2024 16:16:39 -0500 Subject: [PATCH] setup mlflow tracking --- tests/callbacks/test_loggers_across_callbacks.py | 9 +++++++++ tests/conftest.py | 8 -------- tests/loggers/test_mlflow_logger.py | 9 +++++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py index 17886874b1..828d4429f2 100644 --- a/tests/callbacks/test_loggers_across_callbacks.py +++ b/tests/callbacks/test_loggers_across_callbacks.py @@ -17,6 +17,15 @@ ) +@pytest.fixture(autouse=True) +def setup_mlflow_tracking(monkeypatch, tmp_path): + mlflow = pytest.importorskip('mlflow') + # Use a temporary directory instead of 'databricks' + tracking_uri = str(tmp_path / 'mlruns') + monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) + os.makedirs(tracking_uri, exist_ok=True) + + @pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True)) @pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True)) @pytest.mark.filterwarnings('ignore::UserWarning') diff --git a/tests/conftest.py b/tests/conftest.py index 0e0937c476..cd670e2151 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -150,11 +150,3 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus: int): if exitstatus == 5: session.exitstatus = 0 # Ignore no-test-ran errors - -@pytest.fixture(autouse=True) -def setup_mlflow_tracking(monkeypatch, tmp_path): - mlflow = pytest.importorskip('mlflow') - # Use a temporary directory instead of 'databricks' - tracking_uri = str(tmp_path / 'mlruns') - monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) - os.makedirs(tracking_uri, exist_ok=True) diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 9d84baa06f..65a7f53ca7 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -25,6 +25,15 @@ ) +@pytest.fixture(autouse=True) +def setup_mlflow_tracking(monkeypatch, tmp_path): + mlflow = pytest.importorskip('mlflow') + # Use a temporary directory instead of 'databricks' + tracking_uri = str(tmp_path / 'mlruns') + monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri) + os.makedirs(tracking_uri, exist_ok=True) + + def _get_latest_mlflow_run(experiment_name, tracking_uri=None): pytest.importorskip('mlflow') from mlflow import MlflowClient