Skip to content

Commit

Permalink
setup mlflow tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Nov 19, 2024
1 parent 0f8ae76 commit 5a20315
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
9 changes: 9 additions & 0 deletions tests/callbacks/test_loggers_across_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
)


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
# Use a temporary directory instead of 'databricks'
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)


@pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
@pytest.mark.filterwarnings('ignore::UserWarning')
Expand Down
8 changes: 0 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,3 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
if exitstatus == 5:
session.exitstatus = 0 # Ignore no-test-ran errors


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
# Use a temporary directory instead of 'databricks'
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)
9 changes: 9 additions & 0 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
)


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
# Use a temporary directory instead of 'databricks'
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)


def _get_latest_mlflow_run(experiment_name, tracking_uri=None):
pytest.importorskip('mlflow')
from mlflow import MlflowClient
Expand Down

0 comments on commit 5a20315

Please sign in to comment.