# Reusable Tutorials Workflow #6 -- workflow file for this run.
# Note carried over from the page this file was captured from: this file
# contains bidirectional Unicode text that may be interpreted or compiled
# differently than it appears. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Runs the BoTorch tutorials. The same three boolean toggles are exposed for
# both triggers; they select which install/run steps of the `tutorials` job
# are executed.
name: Reusable Tutorials Workflow

on:
  # Manual trigger: every toggle is required and has no default.
  workflow_dispatch:
    inputs:
      # true  -> run scripts/run_tutorials.py with the `-s` flag.
      # false -> run it without `-s`.
      smoke_test:
        required: true
        type: boolean
      # true  -> install the minimum required torch / gpytorch /
      #          linear_operator versions read from the package metadata.
      # false -> install nightly PyTorch and GPyTorch / linear_operator
      #          from their Git repositories.
      use_stable_pytorch_gpytorch:
        required: true
        type: boolean
      # true  -> install ax-platform from the package index.
      # false -> install Ax from its Git repository.
      use_stable_ax:
        required: true
        type: boolean

  # Reusable-workflow trigger: all toggles are optional; the defaults give a
  # smoke-test run against the latest PyTorch / GPyTorch and the latest Ax.
  workflow_call:
    inputs:
      smoke_test:
        required: false
        type: boolean
        default: true
      use_stable_pytorch_gpytorch:
        required: false
        type: boolean
        default: false
      use_stable_ax:
        required: false
        type: boolean
        default: false
jobs:
  # Single job: set up Python, install the selected dependency stack, then run
  # the tutorials script. Steps guarded by `if:` come in mutually exclusive
  # pairs keyed on one boolean input each.
  tutorials:
    name: Run tutorials
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not read the version as the float 3.1.
          python-version: "3.10"

      - name: Fetch all history for all tags and branches
        # We need to do this so setuptools_scm knows how to set the BoTorch version.
        run: git fetch --prune --unshallow

      - name: Install latest PyTorch & GPyTorch
        if: ${{ !inputs.use_stable_pytorch_gpytorch }}
        run: |
          pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
          pip install git+https://github.com/cornellius-gp/linear_operator.git
          pip install git+https://github.com/cornellius-gp/gpytorch.git

      - name: Install min required PyTorch, GPyTorch, and linear_operator
        if: ${{ inputs.use_stable_pytorch_gpytorch }}
        run: |
          # Generate botorch.egg-info/requires.txt so the minimum versions can
          # be read from the package's declared requirements.
          python setup.py egg_info
          req_txt="botorch.egg-info/requires.txt"
          # `\b` keeps the `torch` pattern from matching inside `gpytorch`;
          # sed strips everything except the version digits and dots.
          min_torch_version=$(grep '\btorch[>=]=' "${req_txt}" | sed 's/[^0-9.]//g')
          min_gpytorch_version=$(grep '\bgpytorch[>=]=' "${req_txt}" | sed 's/[^0-9.]//g')
          min_linear_operator_version=$(grep '\blinear_operator[>=]=' "${req_txt}" | sed 's/[^0-9.]//g')
          # NumPy >=2.0 is not compatible with PyTorch <2.2.
          pip install "numpy<2"
          pip install "torch==${min_torch_version}" "gpytorch==${min_gpytorch_version}" "linear_operator==${min_linear_operator_version}" torchvision

      - name: Install BoTorch with tutorials dependencies
        env:
          ALLOW_LATEST_GPYTORCH_LINOP: "true"
        run: |
          # Quoted so the shell never treats `[tutorials]` as a glob pattern.
          pip install ".[tutorials]"

      - name: Install latest Ax
        if: ${{ !inputs.use_stable_ax }}
        env:
          # This is so Ax's setup doesn't install a pinned BoTorch version.
          ALLOW_BOTORCH_LATEST: "true"
        run: |
          pip install git+https://github.com/facebook/Ax.git

      - name: Install stable Ax
        if: ${{ inputs.use_stable_ax }}
        env:
          ALLOW_BOTORCH_LATEST: "true"
        run: |
          # --no-binary makes pip build ax-platform from its source distribution.
          pip install ax-platform --no-binary ax-platform

      - name: Run tutorials with smoke test
        if: ${{ inputs.smoke_test }}
        run: |
          python scripts/run_tutorials.py -p "$(pwd)" -s

      - name: Run tutorials without smoke test
        if: ${{ !inputs.smoke_test }}
        run: |
          python scripts/run_tutorials.py -p "$(pwd)"