Skip to content

Commit

Permalink
docs: document JAX backend
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 26, 2024
1 parent fa61d69 commit 6236cb6
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 10 deletions.
1 change: 1 addition & 0 deletions doc/_static/jax.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
[PyTorch](https://pytorch.org/) 2.0 or above is required.
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.

### JAX {{ jax_icon }}

- Model filename extension: `.xlo`
- Checkpoint filename extension: `.jax`

[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
Currently, this backend is developed actively, and has no support for training and the C++ interface.

### DP {{ dpmodel_icon }}

:::{note}
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
myst_substitutions = {
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
}

Expand Down
1 change: 1 addition & 0 deletions doc/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.

## Python interface only

Expand Down
15 changes: 15 additions & 0 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal

:::

:::{tab-item} JAX {{ jax_icon }}

To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run

```sh
pip install jax-ai-stack
```

One can also install packages in JAX AI Stack manually.
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.

One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).

:::

::::

It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by
Expand Down
8 changes: 8 additions & 0 deletions doc/model/sel.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H

:::

:::{tab-item} JAX {{ jax_icon }}

```sh
dp --jax neighbor-stat -s data -r 6.0 -t O H
```

:::

::::

where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-energy.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-fitting-dos.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-atten.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-a.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-r.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information.
Expand Down

0 comments on commit 6236cb6

Please sign in to comment.