Merge branch 'devel' into enerhess
1azyking authored Nov 13, 2024
2 parents c472c9c + 698b08d commit 23107fa
Showing 38 changed files with 12,927 additions and 56 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test_cc.yml
@@ -27,7 +27,13 @@ jobs:
mpi: mpich
- uses: lukka/get-cmake@latest
- run: python -m pip install uv
- run: source/install/uv_with_retry.sh pip install --system tensorflow
- name: Install Python dependencies
run: |
source/install/uv_with_retry.sh pip install --system tensorflow-cpu
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip
@@ -47,12 +47,6 @@ jobs:
CMAKE_GENERATOR: Ninja
CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }}
# test lammps
- run: |
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp] mpi4py
env:
DP_BUILD_TESTING: 1
if: ${{ !matrix.check_memleak }}
- run: pytest --cov=deepmd source/lmp/tests
env:
OMP_NUM_THREADS: 1
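For reference, the `TENSORFLOW_ROOT` one-liner in this workflow resolves the directory of the installed TensorFlow package. A standalone sketch of what it computes, assuming TensorFlow is installed:

```python
# Resolve the installed tensorflow package directory, as the workflow's
# TENSORFLOW_ROOT one-liner does (e.g. <site-packages>/tensorflow).
import importlib.util
import pathlib

spec = importlib.util.find_spec("tensorflow")
print(pathlib.Path(spec.origin).parent)
```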
7 changes: 5 additions & 2 deletions .github/workflows/test_cuda.yml
@@ -19,7 +19,7 @@ jobs:
runs-on: nvidia
# https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
options: --gpus all
if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
steps:
@@ -63,12 +63,15 @@
CUDA_VISIBLE_DEVICES: 0
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
XLA_PYTHON_CLIENT_PREALLOCATE: false
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
unzip libtorch.zip
- run: |
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
source/install/test_cc_local.sh
env:
OMP_NUM_THREADS: 1
@@ -79,7 +82,7 @@
DP_VARIANT: cuda
DP_USE_MPICH2: 1
- run: |
export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH
export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
python -m pytest source/ipi/tests
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
@@ -54,7 +54,7 @@ jobs:
restore-keys: |
test2-durations-combined-${{ matrix.python }}-${{ github.sha }}
test2-durations-combined-${{ matrix.python }}
- run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration
- run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --clean-durations --durations-path=.test_durations --splitting-algorithm least_duration
env:
NUM_WORKERS: 0
- name: Test TF2 eager mode
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
exclude: ^source/3rdparty
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.2
rev: v0.7.3
hooks:
- id: ruff
args: ["--fix"]
3 changes: 2 additions & 1 deletion README.md
@@ -20,7 +20,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).
### Highlighted features

- **interfaced with multiple backends**, including TensorFlow, PyTorch, and JAX, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABACUS.
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
- **highly modularized**, easy to adapt to different descriptors for deep learning-based potential energy models.
@@ -74,6 +74,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea

- Multiple backends supported. Add PyTorch and JAX backends.
- The DPA-2 model.
- Plugin mechanisms for external models.

## Install and use DeePMD-kit

4 changes: 2 additions & 2 deletions deepmd/dpmodel/model/model.py
@@ -80,8 +80,8 @@ def get_zbl_model(data: dict) -> DPZBLModel:
filepath = data["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
data["descriptor"]["rcut"],
data["descriptor"]["sel"],
descriptor.get_rcut(),
descriptor.get_sel(),
type_map=data["type_map"],
)

71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.
If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.
Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])

if n_nsel < nsel:
# make a copy before revising
ret = tnp.concatenate(
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:
# make a copy before revising
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]
else: # n_nsel == nsel:
ret = nlist
# reshape anyway; this gives XLA a static shape with no dynamic dimensions
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret
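A minimal usage sketch of `format_nlist` (not part of the commit). It assumes DeePMD-kit with the JAX backend is installed and enables `tnp` numpy behavior, which the tensor `.reshape` calls above rely on; here nnei (2) is less than nsel (4), so the padding branch runs:

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import format_nlist

# enable the numpy-style tensor methods (e.g. Tensor.reshape) used above
tnp.experimental_enable_numpy_behavior()

# 1 frame, 2 local atoms, 3 extended atoms; nnei = 2, nsel = 4 -> pad with -1
extended_coord = tf.constant(
    [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], tf.float64
)
nlist = tf.constant([[[1, 2], [0, 2]]], tf.int64)
out = format_nlist(extended_coord, nlist, nsel=4, rcut=6.0)
print(out.shape)  # (1, 2, 4); the two extra columns are filled with -1
```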
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
@@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
@@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
@@ -85,6 +88,7 @@ def call_lower_without_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
@@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
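The relaxed `tf.TensorSpec([None, None, None], ...)` above is what lets callers pass a neighbor list of any width, which `format_nlist` then pads or truncates to `nsel`. A toy sketch of the relaxed-signature behavior (`accepts_any_nnei` is a made-up name for illustration):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None, None], tf.int64)])
def accepts_any_nnei(nlist):
    # any trailing dimension is accepted at the SavedModel boundary
    return tf.shape(nlist)[-1]

print(accepts_any_nnei(tf.zeros([1, 4, 7], tf.int64)))  # 7
print(accepts_any_nnei(tf.zeros([1, 4, 3], tf.int64)))  # 3
```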
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_a.py
@@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):

def tabulate_fusion_se_a(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
14 changes: 7 additions & 7 deletions deepmd/pt/model/descriptor/se_atten.py
@@ -52,13 +52,13 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"):

def tabulate_fusion_se_atten(
argument0,
argument1,
argument2,
argument3,
argument4,
argument5,
argument6,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: torch.Tensor,
argument5: int,
argument6: bool,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_atten is not available since customized PyTorch OP library is not built when freezing the model. "
8 changes: 4 additions & 4 deletions deepmd/pt/model/descriptor/se_r.py
@@ -62,10 +62,10 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"):

def tabulate_fusion_se_r(
argument0,
argument1,
argument2,
argument3,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_r is not available since customized PyTorch OP library is not built when freezing the model. "
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
@@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"):

def tabulate_fusion_se_t(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. "
4 changes: 3 additions & 1 deletion doc/backend.md
@@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
[JAX](https://jax.readthedocs.io/) 0.4.33 or above (which in turn requires Python 3.10 or above) is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
The model is device-specific: a model generated on a GPU device cannot be run on CPUs.
Currently, this backend is under active development and has no support for training.
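For context, models can be converted between backend formats with the `dp convert-backend` subcommand; a hedged sketch (the file names are placeholders, and the subcommand is assumed from the model-conversion steps in the workflows above):

```python
# Convert a JAX-backend model to the TensorFlow SavedModel format for C++
# inference (file names are placeholders; assumes the `dp` CLI is on PATH).
import subprocess

subprocess.run(
    ["dp", "convert-backend", "model.jax", "model.savedmodel"],
    check=True,
)
```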

### DP {{ dpmodel_icon }}

1 change: 1 addition & 0 deletions doc/development/create-a-model-pt.md
@@ -180,6 +180,7 @@ The arguments here should be consistent with the class arguments of your new com
## Package new codes

You may package new codes into a new Python package if you don't want to contribute it to the main DeePMD-kit repository.
A good example is [DeePMD-GNN](https://github.com/njzjz/deepmd-gnn).
It's crucial to add your new component to `project.entry-points."deepmd.pt"` in `pyproject.toml`:

```toml
6 changes: 6 additions & 0 deletions doc/freeze/compress.md
@@ -124,3 +124,9 @@ Notice: Model compression for the `se_atten_v2` descriptor is exclusively design
- relu6
- softplus
- sigmoid

## Requirements of installation {{ pytorch_icon }}

When compressing models in the PyTorch backend, the customized OP library for the Python interface must be installed when [freezing the model](../freeze/freeze.md).

The customized OP library for the Python interface can be installed by setting the environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md).
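A quick runtime check for whether the customized OP library is present, mirroring the `hasattr` guards added in the descriptor modules above. This sketch assumes that importing `deepmd.pt` loads the OP library when it has been built:

```python
import torch

import deepmd.pt  # noqa: F401  -- loads the customized OPs if they were built

if hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):
    print("customized OP library found; compressed models can be frozen")
else:
    print("rebuild with DP_ENABLE_PYTORCH=1 before freezing compressed models")
```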
10 changes: 6 additions & 4 deletions doc/install/install-from-source.md
@@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.

Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.

@@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake

@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} Whether building the TensorFlow backend.
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.

:::

@@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `PATH`

{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.

:::

10 changes: 10 additions & 0 deletions doc/model/dpa2.md
@@ -18,6 +18,16 @@
If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
otherwise the communication between GPU cards falls back to the slower CPU implementation.

## Limitations of the JAX backend with LAMMPS {{ jax_icon }}

When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.

```lammps
atom_modify map yes
```

See the example `examples/water/lmp/jax_dpa2.lammps`.

## Data format

DPA-2 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).