docs: document JAX backend (#4259)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced support for the JAX backend, expanding user options for
    model training and execution.
  - Added installation instructions for JAX within the source installation
    documentation.
  - Included new environment variables related to JAX to enhance
    configuration options.

- **Documentation Updates**
  - Updated various documentation files to reflect the addition of JAX,
    including sections on model commands, supported backends, and
    environment variables.
  - Enhanced documentation with a visual representation for JAX through an
    icon.
  - Improved clarity and organization of installation instructions for
    DeePMD-kit.
  - Updated the README to highlight JAX as a supported backend and reflect
    changes in version history.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Oct 29, 2024
1 parent 82aaa0d commit dd36e6c
Showing 14 changed files with 131 additions and 28 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).

### Highlighted features

- **interfaced with multiple backends**, including TensorFlow and PyTorch, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with multiple backends**, including TensorFlow, PyTorch, and JAX, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABACUS.
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
@@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea

#### v3

- Multiple backends supported. Add a PyTorch backend.
- Multiple backends supported. Add PyTorch and JAX backends.
- The DPA-2 model.

## Install and use DeePMD-kit
1 change: 1 addition & 0 deletions doc/_static/jax.svg
9 changes: 9 additions & 0 deletions doc/backend.md
@@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
[PyTorch](https://pytorch.org/) 2.0 or above is required.
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.

### JAX {{ jax_icon }}

- Model filename extension: `.xlo`
- Checkpoint filename extension: `.jax`

[JAX](https://jax.readthedocs.io/) 0.4.33 or above is required, which in turn requires Python 3.10 or above.
Both `.xlo` and `.jax` are custom format extensions defined by DeePMD-kit, since JAX has no convention for file extensions.
This backend is under active development and does not yet support training or the C++ interface.
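
As an illustration of the extension convention, the following is a minimal sketch (not DeePMD-kit's own code) that classifies a filename by backend and kind. The `EXTENSIONS` table and the `classify` helper are hypothetical names; the JAX entries come from the list above, while the assignment of `.pth` to models and `.pt` to checkpoints for PyTorch is an assumption here.

```python
from pathlib import Path

# JAX entries follow the list above; the PyTorch model/checkpoint
# assignment is an assumption made for this sketch.
EXTENSIONS = {
    ".xlo": ("jax", "model"),
    ".jax": ("jax", "checkpoint"),
    ".pth": ("pytorch", "model"),
    ".pt": ("pytorch", "checkpoint"),
}


def classify(filename: str) -> tuple[str, str]:
    """Return (backend, kind) for a model or checkpoint filename."""
    suffix = Path(filename).suffix.lower()
    try:
        return EXTENSIONS[suffix]
    except KeyError:
        raise ValueError(f"unrecognized extension: {suffix!r}") from None
```

For example, `classify("frozen_model.xlo")` yields `("jax", "model")`, and an unknown extension raises `ValueError`.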

### DP {{ dpmodel_icon }}

:::{note}
1 change: 1 addition & 0 deletions doc/conf.py
@@ -168,6 +168,7 @@
myst_substitutions = {
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
}

1 change: 1 addition & 0 deletions doc/env.md
@@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.
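
For the JAX variables, a minimal sketch of setting them from Python is shown below. The values are illustrative choices, not recommendations: `cpu` restricts JAX to the CPU platform, and `--xla_force_host_platform_device_count` is one example of an XLA flag. The variables need to be set before JAX is first imported in the process.

```python
import os

# Set these before the first `import jax` in this process; JAX reads
# them at import/initialization time.
os.environ["JAX_PLATFORMS"] = "cpu"  # illustrative: run on CPU only
# Illustrative XLA flag: expose 4 virtual host (CPU) devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

# import jax  # import only after the variables above are set
```

Exporting the same variables in the shell before launching `dp` has the same effect.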

## Python interface only

4 changes: 1 addition & 3 deletions doc/install/easy-install-dev.md
@@ -16,14 +16,12 @@ For CUDA 11.8 support, use the `devel_cu11` tag.

## Install with pip

Below is an one-line shell command to download the [artifact](https://nightly.link/deepmodeling/deepmd-kit/workflows/build_wheel/devel/artifact.zip) containing wheels and install it with `pip`:
Follow [the documentation for the stable version](easy-install.md#install-python-interface-with-pip), but add `--pre` and `--extra-index-url` options like below:

```sh
pip install -U --pre deepmd-kit[gpu,cu12,lmp,torch] --extra-index-url https://deepmodeling.github.io/deepmd-kit/simple
```

`cu12` and `lmp` are optional, which is the same as the stable version.

## Download pre-compiled C Library {{ tensorflow_icon }}

:::{note}
96 changes: 83 additions & 13 deletions doc/install/easy-install.md
@@ -104,44 +104,114 @@ docker pull ghcr.io/deepmodeling/deepmd-kit:2.2.8_cuda12.0_gpu

## Install Python interface with pip

If you have no existing TensorFlow installed, you can use `pip` to install the pre-built package of the Python interface with CUDA 12 supported:
[Create a new environment](https://docs.deepmodeling.com/faq/conda.html#how-to-create-a-new-conda-pip-environment), and then execute the following command:

:::::::{tab-set}

::::::{tab-item} TensorFlow {{ tensorflow_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[gpu,cu12,torch]
pip install deepmd-kit[gpu,cu12]
```

`cu12` is required only when CUDA Toolkit and cuDNN were not installed.

To install the package built against CUDA 11.8, use
::::

::::{tab-item} CUDA 11

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install deepmd-kit-cu11[gpu,cu11]
```

Or install the CPU version without CUDA supported:
::::

::::{tab-item} CPU

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install deepmd-kit[cpu]
```

::::

:::::

[The LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md) are only provided on Linux and macOS for the TensorFlow backend. To install LAMMPS and/or i-PI, add `lmp` and/or `ipi` to extras:

```bash
pip install deepmd-kit[gpu,cu12,torch,lmp,ipi]
pip install deepmd-kit[gpu,cu12,lmp,ipi]
```

MPICH is required for parallel running.

:::{Warning}
When installing from pip, only the TensorFlow {{ tensorflow_icon }} backend is supported with LAMMPS and i-PI.
:::
::::::

::::::{tab-item} PyTorch {{ pytorch_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[torch]
```

::::

::::{tab-item} CUDA 11.8

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install deepmd-kit-cu11
```

::::

::::{tab-item} CPU

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install deepmd-kit
```

::::

:::::

::::::

::::::{tab-item} JAX {{ jax_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[jax] jax[cuda12]
```

::::

::::{tab-item} CPU

```bash
pip install deepmd-kit[jax]
```

::::

:::::

::::::

:::::::

It is suggested to install the package into an isolated environment.
Supported platforms include Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64.
Specific versions of TensorFlow and PyTorch that are compatible with DeePMD-kit will also be installed.

:::{Warning}
If your platform is not supported, or you want to build against the installed TensorFlow, or you want to enable ROCM support, please [build from source](install-from-source.md).
If your platform is not supported, or you want to build against the installed backends, or you want to enable ROCM support, please [build from source](install-from-source.md).
:::
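
After installation, one quick way to see which backend packages are importable in the current environment is sketched below. This is only a sanity check under stated assumptions: `BACKEND_MODULES` and `available_backends` are hypothetical names, and finding a package does not by itself confirm that DeePMD-kit can use it.

```python
import importlib.util

# Import names of the backend packages mentioned in the tabs above.
BACKEND_MODULES = {"TensorFlow": "tensorflow", "PyTorch": "torch", "JAX": "jax"}


def available_backends() -> dict[str, bool]:
    """Map each backend name to whether its Python package can be found."""
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in BACKEND_MODULES.items()
    }


if __name__ == "__main__":
    for name, found in available_backends().items():
        print(f"{name}: {'found' if found else 'not found'}")
```

`find_spec` only locates the package without importing it, so the check stays fast even when heavy frameworks are installed.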
15 changes: 15 additions & 0 deletions doc/install/install-from-source.md
@@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal

:::

:::{tab-item} JAX {{ jax_icon }}

To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run

```sh
pip install jax-ai-stack
```

One can also install packages in JAX AI Stack manually.
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.

One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).

:::

::::
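
The backend documentation changed in this commit states that JAX 0.4.33 or above and Python 3.10 or above are required. A minimal sketch of checking both in an existing environment follows; `parse_version` and `jax_requirement_met` are hypothetical helpers, and the simplified parser ignores pre-release and development suffixes.

```python
import re
import sys
from importlib import metadata

MIN_PYTHON = (3, 10)
MIN_JAX = (0, 4, 33)


def parse_version(text: str) -> tuple[int, ...]:
    """Turn '0.4.33' (or '0.4.33.dev1') into (0, 4, 33); suffixes are ignored."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version: {text!r}")
    return tuple(int(part) for part in match.group().split("."))


def jax_requirement_met() -> bool:
    """Check for Python >= 3.10 and an installed jax >= 0.4.33."""
    if sys.version_info[:2] < MIN_PYTHON:
        return False
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return False
    return parse_version(installed) >= MIN_JAX
```

The check reads package metadata only, so it does not import JAX or initialize any device.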

It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by
8 changes: 8 additions & 0 deletions doc/model/sel.md
@@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H

:::

:::{tab-item} JAX {{ jax_icon }}

```sh
dp --jax neighbor-stat -s data -r 6.0 -t O H
```

:::

::::

where `data` is the data directory, `6.0` is the cutoff radius, and `O` and `H` form the type map. The program reports `max_nbor_size`. For example, `max_nbor_size` for the water example is `[38, 72]`, meaning that an atom may have up to 38 O neighbors and 72 H neighbors in the training data.
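
When scripting this step, the command line can be assembled programmatically. The sketch below uses a hypothetical helper, `neighbor_stat_command`, that only builds the argument list shown above and does not run `dp`. It then pairs the reported `max_nbor_size` values from the water example with the type map.

```python
import shlex


def neighbor_stat_command(backend_flag, data_dir, rcut, type_map):
    """Build the argv list for `dp <flag> neighbor-stat`; does not run it."""
    return [
        "dp", backend_flag, "neighbor-stat",
        "-s", data_dir,
        "-r", str(rcut),
        "-t", *type_map,
    ]


type_map = ["O", "H"]
cmd = neighbor_stat_command("--jax", "data", 6.0, type_map)
print(shlex.join(cmd))

# Pair the reported max_nbor_size with the type map
# (values taken from the water example in the text).
max_nbor_size = [38, 72]
neighbors_per_type = dict(zip(type_map, max_nbor_size))
```

If `dp` is installed, the list in `cmd` can be passed to `subprocess.run` unchanged.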
4 changes: 2 additions & 2 deletions doc/model/train-energy.md
@@ -1,7 +1,7 @@
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.
4 changes: 2 additions & 2 deletions doc/model/train-fitting-dos.md
@@ -1,7 +1,7 @@
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector).
4 changes: 2 additions & 2 deletions doc/model/train-se-atten.md
@@ -1,7 +1,7 @@
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-a.md
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003).
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-r.md
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information.