Skip to content

Commit

Permalink
docker gpu attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Jun 9, 2024
1 parent 9a73564 commit 24c3a12
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 6 deletions.
6 changes: 3 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ COPY requirements.txt .
RUN python3.11 -m pip install -r requirements.txt

RUN python3.11 -m pip install \
jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
jax==0.3.25 \
optax==0.1.5
jaxlib==0.4.25+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
jax==0.4.25 \
optax==0.2.2

ENV PYGLFW_PREVIEW=1

Expand Down
69 changes: 68 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ numpy-hilbert-curve = "^1.0.1"
wandb = "^0.17.0"
seaborn = "^0.13.2"
jax = "^0.4.28"
jaxlib = "^0.4.28"
optax = "^0.2.2"
scikit-learn = "^1.5.0"


[build-system]
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ idna==3.7
jax==0.4.28
jaxlib==0.4.28
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
Expand Down Expand Up @@ -50,6 +51,7 @@ referencing==0.35.1
requests==2.32.2
rich==13.7.1
rpds-py==0.18.1
scikit-learn==1.5.0
scipy==1.13.1
seaborn==0.13.2
sentry-sdk==2.4.0
Expand All @@ -59,6 +61,7 @@ smmap==5.0.1
streamlit==1.35.0
sympy==1.12
tenacity==8.3.0
threadpoolctl==3.5.0
toml==0.10.2
toolz==0.12.1
tornado==6.4
Expand Down
1 change: 0 additions & 1 deletion src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# by: Noah Syrkis

# imports
import plotly as py
import pandas as pd
import jax.numpy as jnp
import numpy as np
Expand Down

0 comments on commit 24c3a12

Please sign in to comment.