Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: document JAX backend #4259

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/_static/jax.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t
[PyTorch](https://pytorch.org/) 2.0 or above is required.
While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint.

### JAX {{ jax_icon }}

- Model filename extension: `.xlo`
- Checkpoint filename extension: `.jax`

[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
Currently, this backend is developed actively, and has no support for training and the C++ interface.

### DP {{ dpmodel_icon }}

:::{note}
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
myst_substitutions = {
"tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""",
"pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""",
"jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""",
"dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""",
}

Expand Down
1 change: 1 addition & 0 deletions doc/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod
- If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices.
- {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used.
- {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used.
- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used.

## Python interface only

Expand Down
15 changes: 15 additions & 0 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal

:::

:::{tab-item} JAX {{ jax_icon }}

To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run

```sh
pip install jax-ai-stack
```

One can also install packages in JAX AI Stack manually.
Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA.

One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org).

:::

::::

It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by
Expand Down
8 changes: 8 additions & 0 deletions doc/model/sel.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H

:::

:::{tab-item} JAX {{ jax_icon }}

```sh
dp --jax neighbor-stat -s data -r 6.0 -t O H
```

:::

::::

where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-energy.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file.
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-fitting-dos.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-atten.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-a.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003).
Expand Down
4 changes: 2 additions & 2 deletions doc/model/train-se-e2-r.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information.
Expand Down
Loading