Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating the installation documentation #85

Merged
merged 7 commits into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 126 additions & 43 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,50 +25,119 @@ forward, reverse, and higher-order differentiation, as well as batching using

## Installation

_For now, only a source build is supported._
The easiest ways to install jax-finufft is to install a pre-compiled binary from
PyPI or conda-forge, but if you need GPU support or want to get tuned
performance, you'll want to follow the instructions to install from source as
described below.

For building, you should only need a recent version of Python (>3.6) and
[FFTW](https://www.fftw.org/). GPU-enabled builds also require a working CUDA
compiler (i.e. the CUDA Toolkit), CUDA >= 11.8, and a compatible cuDNN (older versions of CUDA may work but
are untested). At runtime, you'll need `numpy` and `jax`.
### Install binary from PyPI

First, clone the repo and `cd` into the repo root (don't forget the `--recursive` flag because FINUFFT is included as a submodule):
> [!NOTE]
> Only the CPU-enabled build of jax-finufft is available as a binary wheel on
> PyPI. For a GPU-enabled build, you'll need to build from source as described
> below.

To install a binary wheel from [PyPI](https://pypi.org/project/jax-finufft/)
using pip, run the following commands:

```bash
git clone --recursive https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
python -m pip install "jax[cpu]"
python -m pip install jax-finufft
```

Then, you can use `conda` to set up a build environment (but you're welcome to
use whatever workflow works for you!). For example, for a CPU build, you can use:
If this fails, you may need to use a conda-forge binary, or install from source.

### Install binary from conda-forge

> [!NOTE]
> Only the CPU-enabled build of jax-finufft is available as a binary from
> conda-forge. For a GPU-enabled build, you'll need to build from source as
> described below.

To install using [mamba](https://github.com/mamba-org/mamba) (or
[conda](https://docs.conda.io)), run:

```bash
conda create -n jax-finufft -c conda-forge python=3.10 numpy scipy fftw cxx-compiler
conda activate jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
python -m pip install "jax[cpu]"
python -m pip install .
mamba install -c conda-forge jax-finufft
```

The `CPATH` export is needed so that the build can find the headers for libraries like FFTW installed through conda.
### Install from source

#### Dependencies

For a GPU build, while the CUDA libraries and compiler are nominally available through conda,
our experience trying to install them this way suggests that the "traditional"
way of obtaining the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) directly
from NVIDIA may work best (see [related advice for Horovod](https://horovod.readthedocs.io/en/stable/conda_include.html)). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:
Unsurprisingly, a key dependency is JAX, which can be installed following the
directions in [the JAX
documentation](https://jax.readthedocs.io/en/latest/installation.html). If
you're going to want to run on a GPU, make sure that you install the appropriate
JAX build.

The non-Python dependencies that you'll need are:

- [FFTW](https://www.fftw.org),
- [OpenMP](https://www.openmp.org) (for CPU, optional),
- CUDA (for GPU, >= 11.8), and
- cuDNN (for GPU).

Older versions of CUDA may work, but they are untested.

Below we provide some example workflows for installing the required dependencies:

<details>
<summary>Install CPU dependencies with mamba or conda</summary>

```bash
conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
mamba create -n jax-finufft -c conda-forge python jax fftw cxx-compiler
mamba activate jax-finufft
```
</details>

<details>
<summary>Install GPU dependencies with mamba or conda</summary>

For a GPU build, while the CUDA libraries and compiler are nominally available
through conda, our experience trying to install them this way suggests that the
"traditional" way of obtaining the [CUDA
Toolkit](https://developer.nvidia.com/cuda-downloads) directly from NVIDIA may
work best (see [related advice for
Horovod](https://horovod.readthedocs.io/en/stable/conda_include.html)). After
installing the CUDA Toolkit, one can set up the rest of the dependencies with:

```bash
mamba create -n gpu-jax-finufft -c conda-forge python numpy scipy fftw 'gxx<12'
mamba activate gpu-jax-finufft
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we decide to remove this from the default instructions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether or not we need it for this one. The package that seems to set it automatically is cxx-compiler, which isn't installed here, but I haven't tested this one, so I just kept it more or less the same as what you had before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good call. The reason we're using gxx<12 is because CUDA 11 only supports GCC <= 11. CUDA 12 does support GCC 12.1, but cxx-compilers might install a newer version... so let's leave this as-is, I agree!

python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python -m pip install .
```

Other ways of installing JAX are given on the JAX website; the ["local CUDA" install methods](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-locally-harder) are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime.
Other ways of installing JAX are given on the JAX website; the ["local CUDA"
install
methods](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-locally-harder)
are preferred for jax-finufft as this ensures the CUDA extensions are compiled
with the same Toolkit version as the CUDA runtime.
</details>

<details>
<summary>Install GPU dependencies using Flatiron module system</summary>

```bash
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
```
</details>

#### GPU build configuration

In the above `CMAKE_ARGS` line, you'll need to select the CUDA architecture(s) you wish to compile for. To query your GPU's CUDA architecture (compute capability), you can run:
You'll need to configure your build to select the appropriate CUDA
architecture(s) using the environment variable `CMAKE_ARGS`. To query your GPU's
CUDA architecture (compute capability), you can run:

```bash
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
Expand All @@ -78,38 +147,52 @@ $ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
This corresponds to `CMAKE_CUDA_ARCHITECTURES=70`, i.e.:

```bash
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
```

Note that the pip installation is running CMake, so `CMAKE_ARGS` has to be set before then, but is not needed at runtime.
Note that the pip installation below uses CMake, so `CMAKE_ARGS` has to be set
before then, but is not needed at runtime.

At runtime, you may also need:

```bash
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
```

If `CUDA_PATH` isn't set, you'll need to replace it with the path to your CUDA installation in the above line, often something like `/usr/local/cuda`.
If `CUDA_PATH` isn't set, you'll need to replace it with the path to your CUDA
installation in the above line, often something like `/usr/local/cuda`.

For Flatiron users, the following environment setup script can be used instead of conda:
#### Install source from PyPI

<details>
<summary>Environment script</summary>
The source code for all released versions of jax-finufft are available on PyPI,
and this can be installed using:

```bash
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl
python -m pip install --no-binary jax-finufft
```

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
#### Install source from GitHub

Alternatively, you can check out the source repository from GitHub:

```bash
git clone --recurse-submodules https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
```

</details>
> [!NOTE]
> Don't forget the `--recurse-submodules` argument when cloning the repo because
> the upstream FINUFFT library is included as a git submodule. If you do forget,
> you can run `git submodule update --init --recursive` in your local copy to
> checkout the submodule after the initial clone.

After cloning the repository, you can install the local copy using:

```bash
python -m pip install -e .
```

where the `-e` flag optionally runs an "editable" install.

## Usage

Expand Down