diff --git a/Dockerfile b/Dockerfile index 7710100..ce2b7b3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Description: Dockerfile for JAX with CUDA support -FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.0.0-cudnn8-devel-ubuntu20.04 # Set the working directory WORKDIR /workspace @@ -29,8 +29,7 @@ RUN python3.11 -m pip install -r requirements.txt # Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that RUN python3.11 -m pip install --upgrade \ - "jax[cuda11_pip]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ - "jaxlib==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ + pip install -U "jax[cuda12]" \ optax # Set the environment variables