# CI workflow for JAX Triton; GPU testing was added in
# "Introduce GPU testing to JAX Triton" (#230).
name: ci
# Trigger on pushes to main and on pull requests that target main.
on:  # yamllint disable-line rule:truthy
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
# Default GITHUB_TOKEN scope for every job: read-only repository contents.
permissions:
  contents: read # to fetch code
# One active run per workflow and branch. For pull requests the group key uses
# the PR head branch; for pushes it falls back to the ref. A newer run in the
# same group cancels the older one. Note: this also applies to consecutive
# pushes to main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Lint job: runs pre-commit hooks via pre-commit/action on a GitHub-hosted
  # runner.
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
        with:
          # Quoted so YAML does not read it as the float 3.1.
          python-version: '3.10'
      # Action pinned by commit SHA; the ratchet comment records the tag.
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
  # GPU test job. NOTE(review): the runner label suggests an L4 machine with
  # 4 GPUs; confirm against the self-hosted runner inventory.
  test:
    runs-on: linux-x86-g2-48-l4-4gpu
    container:
      # ubuntu:22.04 pinned by digest; the ratchet comment records the tag.
      image: index.docker.io/library/ubuntu@sha256:0e5e4a57c2499249aafc3b40fcd541e9a456aab7296681a3994d631587203f97 # ratchet:ubuntu:22.04
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
        with:
          python-version: '3.10'
      # Installs the latest released JAX with the `cuda12` extra, plus torch
      # and pytest. NOTE(review): torch is presumably required by the tests
      # or by Triton; confirm before removing.
      - name: Setup Released JAX
        run: |
          pip install torch
          pip install -U "jax[cuda12]"
          pip install pytest
      # Installs this package from the checkout and runs the test suite.
      # nvidia-smi prints the GPUs visible inside the container, which helps
      # when debugging the log.
      - name: Test JAX Triton
        run: |
          echo "Running JAX Triton GPU Tests"
          nvidia-smi
          pip install .
          # Need newer ml-dtypes because we install newer numpy
          pip install --upgrade ml-dtypes
          pytest -v --tb=short tests/