Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: document JAX backend #4259

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/_static/jax.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
[PyTorch](https://pytorch.org/) 2.0 or above is required.
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.

### JAX {{ jax_icon }}

- Model filename extension: `.xlo`
- Checkpoint filename extension: `.jax`

[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
Currently, this backend is developed actively, and has no support for training and the C++ interface.

### DP {{ dpmodel_icon }}

:::{note}
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
myst_substitutions = {
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
}

Expand Down
1 change: 1 addition & 0 deletions doc/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.

## Python interface only

Expand Down
4 changes: 1 addition & 3 deletions doc/install/easy-install-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ For CUDA 11.8 support, use the `devel_cu11` tag.

## Install with pip

Below is an one-line shell command to download the [artifact](https://nightly.link/deepmodeling/deepmd-kit/workflows/build_wheel/devel/artifact.zip) containing wheels and install it with `pip`:
Follow [the documentation for the stable version](easy-install.md#install-python-interface-with-pip), but add `--pre` and `--extra-index-url` options like below:

```sh
pip install -U --pre deepmd-kit[gpu,cu12,lmp,torch] --extra-index-url https://deepmodeling.github.io/deepmd-kit/simple
```
njzjz marked this conversation as resolved.
Show resolved Hide resolved

`cu12` and `lmp` are optional, which is the same as the stable version.

## Download pre-compiled C Library {{ tensorflow_icon }}

:::{note}
Expand Down
96 changes: 83 additions & 13 deletions doc/install/easy-install.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,44 +104,114 @@ docker pull ghcr.io/deepmodeling/deepmd-kit:2.2.8_cuda12.0_gpu

## Install Python interface with pip

If you have no existing TensorFlow installed, you can use `pip` to install the pre-built package of the Python interface with CUDA 12 supported:
[Create a new environment](https://docs.deepmodeling.com/faq/conda.html#how-to-create-a-new-conda-pip-environment), and then execute the following command:

:::::::{tab-set}

::::::{tab-item} TensorFlow {{ tensorflow_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[gpu,cu12,torch]
pip install deepmd-kit[gpu,cu12]
```

`cu12` is required only when CUDA Toolkit and cuDNN were not installed.

To install the package built against CUDA 11.8, use
::::

::::{tab-item} CUDA 11

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install deepmd-kit-cu11[gpu,cu11]
```

Or install the CPU version without CUDA supported:
::::

::::{tab-item} CPU

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install deepmd-kit[cpu]
```

::::

:::::

[The LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md) are only provided on Linux and macOS for the TensorFlow backend. To install LAMMPS and/or i-PI, add `lmp` and/or `ipi` to extras:

```bash
pip install deepmd-kit[gpu,cu12,torch,lmp,ipi]
pip install deepmd-kit[gpu,cu12,lmp,ipi]
```

MPICH is required for parallel running.

:::{Warning}
When installing from pip, only the TensorFlow {{ tensorflow_icon }} backend is supported with LAMMPS and i-PI.
:::
::::::

::::::{tab-item} PyTorch {{ pytorch_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[torch]
```

::::

::::{tab-item} CUDA 11.8

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install deepmd-kit-cu11
```

::::

::::{tab-item} CPU

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install deepmd-kit
```

::::

:::::

::::::

::::::{tab-item} JAX {{ jax_icon }}

:::::{tab-set}

::::{tab-item} CUDA 12

```bash
pip install deepmd-kit[jax] jax[cuda12]
```

::::

::::{tab-item} CPU

```bash
pip install deepmd-kit[jax]
```

::::

:::::

::::::

:::::::
njzjz marked this conversation as resolved.
Show resolved Hide resolved

It is suggested to install the package into an isolated environment.
The supported platform includes Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64.
A specific version of TensorFlow and PyTorch which is compatible with DeePMD-kit will be also installed.

:::{Warning}
If your platform is not supported, or you want to build against the installed TensorFlow, or you want to enable ROCM support, please [build from source](install-from-source.md).
If your platform is not supported, or you want to build against the installed backends, or you want to enable ROCM support, please [build from source](install-from-source.md).
:::
15 changes: 15 additions & 0 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal

:::

:::{tab-item} JAX {{ jax_icon }}

To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run

```sh
pip install jax-ai-stack
```

One can also install packages in JAX AI Stack manually.
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.

One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).

:::

::::

It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by
Expand Down
8 changes: 8 additions & 0 deletions doc/model/sel.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H

:::

:::{tab-item} JAX {{ jax_icon }}

```sh
dp --jax neighbor-stat -s data -r 6.0 -t O H
```

:::

::::

where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-energy.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-fitting-dos.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-atten.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-a.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-r.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information.
Expand Down