diff --git a/README.md b/README.md index e821a29768..55f927d62b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/). ### Highlighted features -- **interfaced with multiple backends**, including TensorFlow and PyTorch, the most popular deep learning frameworks, making the training process highly automatic and efficient. +- **interfaced with multiple backends**, including TensorFlow, PyTorch, and JAX, the most popular deep learning frameworks, making the training process highly automatic and efficient. - **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS. - **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc. - **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing. @@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea #### v3 -- Multiple backends supported. Add a PyTorch backend. +- Multiple backends supported. Add PyTorch and JAX backends. - The DPA-2 model. ## Install and use DeePMD-kit diff --git a/doc/_static/jax.svg b/doc/_static/jax.svg new file mode 100644 index 0000000000..360a6624d4 --- /dev/null +++ b/doc/_static/jax.svg @@ -0,0 +1 @@ + diff --git a/doc/backend.md b/doc/backend.md index f6eaf0e45b..cf99eea9cb 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -23,6 +23,15 @@ DeePMD-kit does not use the TensorFlow v2 API but uses the TensorFlow v1 API (`t [PyTorch](https://pytorch.org/) 2.0 or above is required. While `.pth` and `.pt` are the same in the PyTorch package, they have different meanings in the DeePMD-kit to distinguish the model and the checkpoint. +### JAX {{ jax_icon }} + +- Model filename extension: `.xlo` +- Checkpoint filename extension: `.jax` + +[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. +Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. +Currently, this backend is developed actively, and has no support for training and the C++ interface. + ### DP {{ dpmodel_icon }} :::{note} diff --git a/doc/conf.py b/doc/conf.py index c72e05bf8a..eca7665712 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -168,6 +168,7 @@ myst_substitutions = { "tensorflow_icon": """![TensorFlow](/_static/tensorflow.svg){class=platform-icon}""", "pytorch_icon": """![PyTorch](/_static/pytorch.svg){class=platform-icon}""", + "jax_icon": """![JAX](/_static/jax.svg){class=platform-icon}""", "dpmodel_icon": """![DP](/_static/logo_icon.svg){class=platform-icon}""", } diff --git a/doc/env.md b/doc/env.md index 65a50ff163..3cf42b724a 100644 --- a/doc/env.md +++ b/doc/env.md @@ -31,6 +31,7 @@ See [How to control the parallelism of a job](./troubleshooting/howtoset_num_nod - If ROCm is used, [ROCm environment variables](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#environment-variables) can be used to control ROCm devices. - {{ tensorflow_icon }} If TensorFlow is used, TensorFlow environment variables can be used. - {{ pytorch_icon }} If PyTorch is used, [PyTorch environment variables](https://pytorch.org/docs/stable/torch_environment_variables.html) can be used. +- {{ jax_icon }} [`JAX_PLATFORMS`](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) and [`XLA_FLAGS`](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) are commonly used. ## Python interface only diff --git a/doc/install/easy-install-dev.md b/doc/install/easy-install-dev.md index bb68272ace..54309a8582 100644 --- a/doc/install/easy-install-dev.md +++ b/doc/install/easy-install-dev.md @@ -16,14 +16,12 @@ For CUDA 11.8 support, use the `devel_cu11` tag. ## Install with pip -Below is an one-line shell command to download the [artifact](https://nightly.link/deepmodeling/deepmd-kit/workflows/build_wheel/devel/artifact.zip) containing wheels and install it with `pip`: +Follow [the documentation for the stable version](easy-install.md#install-python-interface-with-pip), but add `--pre` and `--extra-index-url` options like below: ```sh pip install -U --pre deepmd-kit[gpu,cu12,lmp,torch] --extra-index-url https://deepmodeling.github.io/deepmd-kit/simple ``` -`cu12` and `lmp` are optional, which is the same as the stable version. - ## Download pre-compiled C Library {{ tensorflow_icon }} :::{note} diff --git a/doc/install/easy-install.md b/doc/install/easy-install.md index 99962d08b8..c2260b58b6 100644 --- a/doc/install/easy-install.md +++ b/doc/install/easy-install.md @@ -104,44 +104,114 @@ docker pull ghcr.io/deepmodeling/deepmd-kit:2.2.8_cuda12.0_gpu ## Install Python interface with pip -If you have no existing TensorFlow installed, you can use `pip` to install the pre-built package of the Python interface with CUDA 12 supported: +[Create a new environment](https://docs.deepmodeling.com/faq/conda.html#how-to-create-a-new-conda-pip-environment), and then execute the following command: + +:::::::{tab-set} + +::::::{tab-item} TensorFlow {{ tensorflow_icon }} + +:::::{tab-set} + +::::{tab-item} CUDA 12 ```bash -pip install deepmd-kit[gpu,cu12,torch] +pip install deepmd-kit[gpu,cu12] ``` `cu12` is required only when CUDA Toolkit and cuDNN were not installed. -To install the package built against CUDA 11.8, use +:::: + +::::{tab-item} CUDA 11 ```bash -pip install torch --index-url https://download.pytorch.org/whl/cu118 pip install deepmd-kit-cu11[gpu,cu11] ``` -Or install the CPU version without CUDA supported: +:::: + +::::{tab-item} CPU ```bash -pip install torch --index-url https://download.pytorch.org/whl/cpu pip install deepmd-kit[cpu] ``` +:::: + +::::: + [The LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md) are only provided on Linux and macOS for the TensorFlow backend. To install LAMMPS and/or i-PI, add `lmp` and/or `ipi` to extras: ```bash -pip install deepmd-kit[gpu,cu12,torch,lmp,ipi] +pip install deepmd-kit[gpu,cu12,lmp,ipi] ``` MPICH is required for parallel running. -:::{Warning} -When installing from pip, only the TensorFlow {{ tensorflow_icon }} backend is supported with LAMMPS and i-PI. -::: +:::::: + +::::::{tab-item} PyTorch {{ pytorch_icon }} + +:::::{tab-set} + +::::{tab-item} CUDA 12 + +```bash +pip install deepmd-kit[torch] +``` + +:::: + +::::{tab-item} CUDA 11.8 + +```bash +pip install torch --index-url https://download.pytorch.org/whl/cu118 +pip install deepmd-kit-cu11 +``` + +:::: + +::::{tab-item} CPU + +```bash +pip install torch --index-url https://download.pytorch.org/whl/cpu +pip install deepmd-kit +``` + +:::: + +::::: + +:::::: + +::::::{tab-item} JAX {{ jax_icon }} + +:::::{tab-set} + +::::{tab-item} CUDA 12 + +```bash +pip install deepmd-kit[jax] jax[cuda12] +``` + +:::: + +::::{tab-item} CPU + +```bash +pip install deepmd-kit[jax] +``` + +:::: + +::::: + +:::::: + +::::::: -It is suggested to install the package into an isolated environment. The supported platform includes Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64. -A specific version of TensorFlow and PyTorch which is compatible with DeePMD-kit will be also installed. :::{Warning} -If your platform is not supported, or you want to build against the installed TensorFlow, or you want to enable ROCM support, please [build from source](install-from-source.md). +If your platform is not supported, or you want to build against the installed backends, or you want to enable ROCM support, please [build from source](install-from-source.md). ::: diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 07239cd3b7..4a0a104b7e 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -78,6 +78,21 @@ One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to instal ::: +:::{tab-item} JAX {{ jax_icon }} + +To install [JAX AI Stack](https://github.com/jax-ml/jax-ai-stack), run + +```sh +pip install jax-ai-stack +``` + +One can also install packages in JAX AI Stack manually. +Follow [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html) to install JAX built against different CUDA versions or without CUDA. + +One can also [use conda](https://docs.deepmodeling.org/faq/conda.html) to install JAX from [conda-forge](https://conda-forge.org). + +::: + :::: It is important that every time a new shell is started and one wants to use `DeePMD-kit`, the virtual environment should be activated by diff --git a/doc/model/sel.md b/doc/model/sel.md index 4908954618..babea1d463 100644 --- a/doc/model/sel.md +++ b/doc/model/sel.md @@ -24,6 +24,14 @@ dp --pt neighbor-stat -s data -r 6.0 -t O H ::: +:::{tab-item} JAX {{ jax_icon }} + +```sh +dp --jax neighbor-stat -s data -r 6.0 -t O H +``` + +::: + :::: where `data` is the directory of data, `6.0` is the cutoff radius, and `O` and `H` is the type map. The program will give the `max_nbor_size`. For example, `max_nbor_size` of the water example is `[38, 72]`, meaning an atom may have 38 O neighbors and 72 H neighbors in the training data. diff --git a/doc/model/train-energy.md b/doc/model/train-energy.md index 75d31d4670..484564b14f 100644 --- a/doc/model/train-energy.md +++ b/doc/model/train-energy.md @@ -1,7 +1,7 @@ -# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Fit energy {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.json` as an example of the input file. diff --git a/doc/model/train-fitting-dos.md b/doc/model/train-fitting-dos.md index d04dbc669c..fb4a3677e5 100644 --- a/doc/model/train-fitting-dos.md +++ b/doc/model/train-fitting-dos.md @@ -1,7 +1,7 @@ -# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Fit electronic density of states (DOS) {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: Here we present an API to DeepDOS model, which can be used to fit electronic density of state (DOS) (which is a vector). diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index bebce78365..3e88a4e950 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -1,7 +1,7 @@ -# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"se_atten"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: ## DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation diff --git a/doc/model/train-se-e2-a.md b/doc/model/train-se-e2-a.md index 81b95399e0..d4a4510a31 100644 --- a/doc/model/train-se-e2-a.md +++ b/doc/model/train-se-e2-a.md @@ -1,7 +1,7 @@ -# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"se_e2_a"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: The notation of `se_e2_a` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from all information (both angular and radial) of atomic configurations. The `e2` stands for the embedding with two-atoms information. This descriptor was described in detail in [the DeepPot-SE paper](https://arxiv.org/abs/1805.09003). diff --git a/doc/model/train-se-e2-r.md b/doc/model/train-se-e2-r.md index 316bde43b4..baff6d6331 100644 --- a/doc/model/train-se-e2-r.md +++ b/doc/model/train-se-e2-r.md @@ -1,7 +1,7 @@ -# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"se_e2_r"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: The notation of `se_e2_r` is short for the Deep Potential Smooth Edition (DeepPot-SE) constructed from the radial information of atomic configurations. The `e2` stands for the embedding with two-atom information.