Skip to content

Commit

Permalink
cuda 12 ??
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Jun 15, 2024
1 parent e549b95 commit a9183a4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Description: Dockerfile for JAX with CUDA support
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
FROM nvidia/cuda:12.0.0-cudnn8-devel-ubuntu20.04

# Set the working directory
WORKDIR /workspace
Expand Down Expand Up @@ -29,8 +29,7 @@ RUN python3.11 -m pip install -r requirements.txt

# Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that
RUN python3.11 -m pip install --upgrade \
"jax[cuda11_pip]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
"jaxlib==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
pip install -U "jax[cuda12]" \
optax

# Set the environment variables
Expand Down

0 comments on commit a9183a4

Please sign in to comment.