diff --git a/README.md b/README.md
index 6ce200a..38eb6d3 100644
--- a/README.md
+++ b/README.md
@@ -79,7 +79,6 @@ The non-Python dependencies that you'll need are:
 - [FFTW](https://www.fftw.org),
 - [OpenMP](https://www.openmp.org) (for CPU, optional),
 - CUDA (for GPU, >= 11.8), and
-- cuDNN (for GPU).
 
 Older versions of CUDA may work, but they are untested.
 
@@ -116,47 +115,68 @@ Other ways of installing JAX are given on the JAX website; the ["local CUDA"
 install methods](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-locally-harder)
 are preferred for jax-finufft as this ensures the CUDA extensions are compiled
-with the same Toolkit version as the CUDA runtime.
+with the same Toolkit version as the CUDA runtime. However, this is not required
+as long as both JAX and jax-finufft use CUDA with the same major version.
 Install GPU dependencies using Flatiron module system
 ```bash
-ml modules/2.2
-ml gcc
-ml python/3.11
-ml fftw
-ml cuda/11
-ml cudnn
-ml nccl
-
-export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+ml modules/2.3 \
+   gcc \
+   python/3.11 \
+   fftw \
+   cuda/12
+
 export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
 ```
+
+#### Configuring the build
+There are several important CMake variables that control aspects of the jax-finufft and (cu)finufft builds. These include:
+
+- **`JAX_FINUFFT_USE_CUDA`** [disabled by default]: build with GPU support
+- **`CMAKE_CUDA_ARCHITECTURES`** [default `native`]: the target GPU architecture. `native` means the GPU arch of the build system.
+- **`FINUFFT_ARCH_FLAGS`** [default `-march=native`]: the target CPU architecture. `native` means the CPU arch of the build system.
+
+Each of these can be set as `-Ccmake.define.NAME=VALUE` arguments to `pip install`. For example,
+to build with GPU support from the repo root, run:
+
+```bash
+pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON .
+```
+
+Use multiple `-C` arguments to set multiple variables. The `-C` argument will work with any of the source installation methods (e.g. PyPI source dist, GitHub, etc).
+
+Build options can also be set with the `CMAKE_ARGS` environment variable. For example:
+
+```bash
+export CMAKE_ARGS="$CMAKE_ARGS -DJAX_FINUFFT_USE_CUDA=ON"
+```
+
 #### GPU build configuration
+Building with GPU support requires passing `JAX_FINUFFT_USE_CUDA=ON` to CMake. See [Configuring the build](#configuring-the-build).
 
-You'll need to configure your build to select the appropriate CUDA
-architecture(s) using the environment variable `CMAKE_ARGS`. To query your GPU's
-CUDA architecture (compute capability), you can run:
+By default, jax-finufft will build for the GPU of the build machine. If you need to target
+a different compute capability, such as 8.0 for Ampere, set `CMAKE_CUDA_ARCHITECTURES` as a CMake define:
 
 ```bash
-$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
-7.0
+pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON -Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80 .
 ```
 
-This corresponds to `CMAKE_CUDA_ARCHITECTURES=70`, i.e.:
+`CMAKE_CUDA_ARCHITECTURES` also takes a semicolon-separated list.
+To detect the arch for a specific GPU, one can run:
 
 ```bash
-export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
+$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
+8.0
 ```
 
-Note that the pip installation below uses CMake, so `CMAKE_ARGS` has to be set
-before then, but is not needed at runtime.
+The values are also listed on the [NVIDIA website](https://developer.nvidia.com/cuda-gpus).
 
-At runtime, you may also need:
+In some cases, you may also need the following at runtime:
 
 ```bash
 export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
@@ -197,6 +217,13 @@ python -m pip install -e .
 
 where the `-e` flag optionally runs an "editable" install.
 
+As yet another alternative, the latest development version from GitHub can be
+installed directly (i.e. without cloning first) with
+
+```bash
+python -m pip install git+https://github.com/flatironinstitute/jax-finufft.git
+```
+
 ## Usage
 
 This library provides two high-level functions (and these should be all that you