-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathsetup.sh
113 lines (101 loc) · 4.27 KB
/
setup.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/bin/bash
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# bash setup.sh MODE={stable,nightly} DEVICE={tpu,gpu}
# MODE is optional and defaults to stable; DEVICE is optional and defaults to tpu.
# For MODE=stable you may additionally specify JAX_VERSION, e.g. JAX_VERSION=0.4.33
# Abort the whole script as soon as any command fails.
set -e
export DEBIAN_FRONTEND=noninteractive

# Decide once how to get a root shell: run bash directly when already root or
# when sudo is not installed, otherwise go through sudo. A `sudo bash || bash`
# chain would hand the fallback shell whatever part of the here-doc the first
# shell left unread after a failure.
if [[ "$(id -u)" -eq 0 ]] || ! command -v sudo >/dev/null 2>&1; then
  root_shell=(bash)
else
  root_shell=(sudo bash)
fi

# The here-doc delimiter is quoted, so everything below is passed verbatim and
# expanded only by the root shell.
"${root_shell[@]}" <<'EOF'
# This is a separate shell: it does not inherit `set -e` from the caller, and
# sudo normally resets the environment, so both settings are repeated here.
set -eo pipefail
export DEBIAN_FRONTEND=noninteractive

# Base tooling (lsb-release and curl are needed by the steps below).
apt update
apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool

# Register the gcsfuse apt repository for this distribution's codename, trusting
# only Google's signing key for it (apt-key is deprecated and no longer shipped
# on recent Debian/Ubuntu releases).
GCSFUSE_REPO="gcsfuse-$(lsb_release -c -s)"
mkdir -p /usr/share/keyrings
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg -o /usr/share/keyrings/cloud.google.asc
echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt ${GCSFUSE_REPO} main" | tee /etc/apt/sources.list.d/gcsfuse.list
apt update -y
apt -y install gcsfuse

# Drop the downloaded package indexes to keep images small.
rm -rf /var/lib/apt/lists/*
EOF
# Set environment variables from command line arguments
#######################################
# Exports each KEY=VALUE argument into the environment.
# The argument is split at its first '=', so VALUE may itself contain '='
# (including a trailing one) or newlines. An argument without any '=' is
# exported with an empty value.
# Arguments: zero or more KEY=VALUE strings
# Outputs:   an error message on stderr for a malformed argument
# Returns:   0 on success, 1 if KEY is not a valid variable name or the
#            export itself fails
#######################################
export_key_value_args() {
  local argument key value
  for argument in "$@"; do
    key=${argument%%=*}
    if [[ "$argument" == *=* ]]; then
      value=${argument#*=}
    else
      value=""
    fi
    if [[ ! "$key" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
      printf 'Error: invalid argument "%s"; expected KEY=VALUE.\n' "$argument" >&2
      return 1
    fi
    export "$key=$value" || return 1
  done
}
export_key_value_args "$@" || exit 1
# Fall back to TPU when no DEVICE was supplied (unset or empty).
[[ -n "$DEVICE" ]] || export DEVICE="tpu"

# The literal value "NONE" means "no version pin": drop the variable entirely.
case "$JAX_VERSION" in
  NONE) unset JAX_VERSION ;;
esac

# A version pin is only meaningful for stable installs; an empty MODE counts
# as stable here.
if [[ -n "$JAX_VERSION" && -n "$MODE" && "$MODE" != "stable" ]]; then
  echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode.\n\n"
  exit 1
fi
# Install JAX and JAXlib based on the specified mode
#######################################
# Installs JAX and its device-specific companions for the requested mode.
# Globals:
#   MODE           (read)     "stable" or "nightly"; unset or empty selects
#                             stable, the documented default
#   DEVICE         (read)     "tpu" or "gpu"; any other value installs no
#                             JAX build and only produces a warning
#   JAX_VERSION    (read)     optional exact version pin (stable mode)
#   NVTE_FRAMEWORK (exported) set to "jax" before building Transformer Engine
# Outputs: progress messages on stdout, warnings on stderr
# Exits:   with status 1 when MODE is neither "stable" nor "nightly"
#######################################
install_jax() {
  local mode="${MODE:-stable}"

  case "$mode" in
    stable)
      case "$DEVICE" in
        tpu)
          echo "Installing stable jax, jaxlib for tpu"
          if [[ -n "$JAX_VERSION" ]]; then
            echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
            pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          else
            echo "Installing stable jax, jaxlib, libtpu for tpu"
            pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          fi
          ;;
        gpu)
          echo "Installing stable jax, jaxlib for NVIDIA gpu"
          if [[ -n "$JAX_VERSION" ]]; then
            echo "Installing stable jax, jaxlib ${JAX_VERSION}"
            pip3 install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
          else
            echo "Installing latest stable jax, jaxlib for NVIDIA gpu"
            pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
          fi
          # Transformer Engine reads NVTE_FRAMEWORK at build time.
          export NVTE_FRAMEWORK=jax
          pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
          ;;
        *)
          echo "Warning: unrecognized DEVICE '${DEVICE}'; no JAX build installed by this step." >&2
          ;;
      esac
      ;;
    nightly)
      case "$DEVICE" in
        gpu)
          echo "Installing jax-nightly, jaxlib-nightly"
          # The extras spec is quoted so the brackets are never treated as a
          # glob; pip3 is used like every other install in this script so all
          # packages land in the same interpreter.
          pip3 install -U --pre jax jaxlib 'jax-cuda12-plugin[with_cuda]' jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
          # Transformer Engine reads NVTE_FRAMEWORK at build time.
          export NVTE_FRAMEWORK=jax
          pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
          ;;
        tpu)
          echo "Installing jax-nightly,jaxlib-nightly"
          # jax, jaxlib and libtpu nightlies each come from their own index.
          pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
          pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
          pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          ;;
        *)
          echo "Warning: unrecognized DEVICE '${DEVICE}'; no JAX build installed by this step." >&2
          ;;
      esac
      echo "Installing nightly tensorboard plugin profile"
      pip3 install tbp-nightly --upgrade
      ;;
    *)
      echo -e "\n\nError: You can only set MODE to [stable,nightly].\n\n"
      exit 1
      ;;
  esac
}
# Called bare (not in an `||`/`if` context) so `set -e` stays active inside.
install_jax
# Bring in (and upgrade) every package listed in requirements.txt.
pip3 install --upgrade --requirement requirements.txt
# Finally install maxdiffusion itself from the current directory.
pip3 install --upgrade .