# Skip to content
#
# Jax-Metal CI
#
# Jax-Metal CI #64

# JAX-Metal plugin CI
#
# Installs jax from this checkout together with the jax-metal plugin on a
# self-hosted macOS ARM64 runner and runs tests/lax_metal_test.py.
name: Jax-Metal CI
on:
  schedule:
    - cron: "0 12 * * *"  # Daily at 12:00 UTC
  workflow_dispatch:  # allows triggering the workflow run manually
  # Automatically trigger on pull requests against main that touch this
  # workflow file; other pull requests do not start this workflow.
  pull_request:
    branches:
      - main
    paths:
      - '**workflows/metal_plugin_ci.yml'
jobs:
  # Runs the Metal plugin test file once per jaxlib flavour in the matrix.
  jax-metal-plugin-test:
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # "nightly" pre-installs a pre-release jaxlib from the nightly index
        # (see the setup step). "pypi_latest" installs no jaxlib explicitly,
        # so it arrives only as a dependency of the later pip installs.
        jaxlib-version: ["pypi_latest", "nightly"]
    name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
    runs-on: [self-hosted, macOS, ARM64]
    steps:
      - name: Get repo
        uses: actions/checkout@v4
        with:
          # Check out into ./jax so the later `cd jax` steps find the sources.
          path: jax
      - name: Setup build and test environment
        run: |
          # Always start from a fresh virtualenv; the workspace of a
          # self-hosted runner may persist between runs.
          rm -rf "${GITHUB_WORKSPACE}/jax-metal-venv"
          python3 -m venv "${GITHUB_WORKSPACE}/jax-metal-venv"
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          pip install -U pip numpy wheel
          pip install absl-py pytest
          if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
          fi
          # Install jax from the checked-out sources, then the Metal plugin.
          cd jax
          pip install .
          pip install jax-metal
      - name: Run test
        run: |
          # Each `run` step starts a new shell, so re-activate the venv.
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          # NOTE(review): presumably lets the plugin load against a newer
          # jaxlib/PJRT API version - confirm against the jax-metal docs.
          export ENABLE_PJRT_COMPATIBILITY=1
          cd jax
          pytest tests/lax_metal_test.py