# Workflow file for the CI run of "Introduce GPU testing to JAX Triton" (#225).

name: ci

# Run on pushes to main and on pull requests that target main.
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

permissions:
  contents: read # to fetch code

# Runs are grouped by workflow name plus the PR head ref (or the pushed ref
# when there is no head ref); a newer run in the same group cancels the one
# still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Runs the pre-commit action against the checked-out code on ubuntu-latest.
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
        with:
          # Keep the quotes: unquoted 3.10 would be read as the float 3.1.
          python-version: '3.10'
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1

  # Installs the released CUDA 12 build of JAX plus this package inside the
  # ml-build container, then runs the pytest suite under tests/ on the
  # linux-x86-g2-48-l4-4gpu runner.
  test:
    runs-on: linux-x86-g2-48-l4-4gpu
    container:
      # TODO: change image based on what is needed for these tests
      image: us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest@sha256:772fe35a3bfd0112ac274771e2183f9f332900f01a655b51e2ca28463b098842 # ratchet:us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest
    steps:
      - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
      # - name: Setup Compat Driver
      #   run: |
      #     # This container should already have the CUDA apt repos setup
      #     apt-get update
      #     apt-get install -y --no-install-recommends cuda-compat-12-6
      # NOTE(review): presumably pauses the job so a developer can connect to
      # the runner for debugging; confirm what halt-dispatch-input expects.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@059880faa02ce79cc9969629e652005ab8b5f332 # ratchet:google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: "1"
      - name: Setup Released JAX
        run: |
          pip install -U "jax[cuda12]"
          pip install pytest
      - name: Test JAX Triton
        run: |
          echo "Running JAX Triton GPU Tests"
          nvidia-smi
          pip install .
          # Need newer ml-dtypes because we install newer numpy
          pip install --upgrade ml-dtypes
          pytest -v --tb=short tests/