feat(jax): SavedModel C++ interface (including DPA-2 supports) #4307

Merged: 81 commits merged on Nov 13, 2024

Changes from 80 commits

Commits (81)
147400e
feat: saved model C++ interface
njzjz Nov 4, 2024
8c6d522
model
njzjz Nov 4, 2024
140f3e1
update test data
njzjz Nov 4, 2024
a0b8074
need CPU model
njzjz Nov 4, 2024
e6bf59f
skip memory check
njzjz Nov 4, 2024
6c10e8e
fix
njzjz Nov 4, 2024
2aa6deb
Apply suggestions from code review
njzjz Nov 4, 2024
f16dd92
Update source/api_cc/src/DeepPotJAX.cc
njzjz Nov 4, 2024
297ae26
debug memory leak
njzjz Nov 4, 2024
e64e06a
add LAMMPS test
njzjz Nov 4, 2024
8fefce8
fix memory leak in add_input
njzjz Nov 4, 2024
261c7bd
pass reference
njzjz Nov 4, 2024
4d5ccc5
delete function and retvals
njzjz Nov 4, 2024
d365bbc
Merge branch 'savedmodel-cxx-debug-mem' into savedmodel-cxx
njzjz Nov 4, 2024
21fc045
no need to skip the test
njzjz Nov 4, 2024
660171e
Merge branch 'devel' into savedmodel-cxx
njzjz Nov 4, 2024
d552821
Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
njzjz Nov 4, 2024
0461248
add limitation
njzjz Nov 4, 2024
f26f3fe
fix tf string parse
njzjz Nov 4, 2024
713d065
Update source/api_cc/tests/test_deeppot_jax.cc
njzjz Nov 4, 2024
ccb182d
cast void*
njzjz Nov 4, 2024
8ccead6
handle zero atom
njzjz Nov 5, 2024
904042d
Merge branch 'devel' into jax-cxx-dpa2
njzjz Nov 5, 2024
0f9d5c5
feat(jax): DPA-2 for LAMMPS
njzjz Nov 5, 2024
bad564b
use the cpu model
njzjz Nov 5, 2024
2b165d7
fix function name
njzjz Nov 5, 2024
e717ba3
fix typos
njzjz Nov 5, 2024
f075075
nloc_real -> nall_real
njzjz Nov 6, 2024
58dcf2b
document limation
njzjz Nov 6, 2024
d93d13a
Merge branch 'devel' into jax-cxx-dpa1
Nov 9, 2024
232f7cd
fix(tf): fix normalize when compressing a model converted from other …
Nov 10, 2024
ce9ee61
apply padding method
Nov 10, 2024
6b10eb7
update model
njzjz Nov 11, 2024
afc71cb
Merge commit 'ce9ee61e71b83d2c682522706f98955dfecea98a' into jax-cxx-…
njzjz Nov 11, 2024
e1a2b55
Merge remote-tracking branch 'origin/devel' into reformat-jax-cxx
njzjz Nov 11, 2024
649f98e
update base class
njzjz Nov 11, 2024
1cad0b2
perhaps PADDING_FACTOR doesn't need so much
njzjz Nov 11, 2024
239d186
use max size
njzjz Nov 11, 2024
37c8739
bump API version
njzjz Nov 11, 2024
b863c79
update model
njzjz Nov 11, 2024
95ad9d0
update model
njzjz Nov 11, 2024
b6d039f
Revert "use max size"
njzjz Nov 11, 2024
5e2ea67
test
njzjz Nov 11, 2024
72a23d2
debug
njzjz Nov 11, 2024
edc4445
add all functions
njzjz Nov 11, 2024
b0808f1
Reapply "use max size"
njzjz Nov 11, 2024
458be34
Revert "debug"
njzjz Nov 11, 2024
87908c3
Revert "test"
njzjz Nov 11, 2024
3a0ca2d
Revert "update model"
njzjz Nov 11, 2024
1863b27
Revert "update model"
njzjz Nov 11, 2024
eb549e5
cast type
njzjz Nov 11, 2024
8a154bd
update model
njzjz Nov 11, 2024
4dab4fb
bugfix
njzjz Nov 11, 2024
c4f08c8
fix OOM issue
njzjz Nov 11, 2024
ef70135
no nlist interface
njzjz Nov 11, 2024
be02814
fix skip
njzjz Nov 11, 2024
3c46f37
try to reduce memory
njzjz Nov 12, 2024
49f57bc
fix skip tests
njzjz Nov 12, 2024
e8a99f4
also skip lammps dpa-2 tests for CUDA
njzjz Nov 12, 2024
8f83a28
should be fw
njzjz Nov 12, 2024
8c05d54
Revert "should be fw"
njzjz Nov 12, 2024
88be054
Revert "try to reduce memory"
njzjz Nov 12, 2024
5cfc83c
Revert "fix OOM issue"
njzjz Nov 12, 2024
9af5267
set --clean-durations
njzjz Nov 12, 2024
01567d6
Merge branch 'devel' into savedmodel-cxx
njzjz Nov 12, 2024
dc4a9d7
Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
njzjz Nov 12, 2024
1234489
add example
njzjz Nov 12, 2024
86d1b7a
convert models at runtime
njzjz Nov 12, 2024
93cc440
add script path
njzjz Nov 12, 2024
546f7dc
revert strict=False
njzjz Nov 12, 2024
0d51bcc
revert .gitignore
njzjz Nov 12, 2024
fc1f90d
prefer cuda's cudnn
njzjz Nov 13, 2024
9447603
bump cuda version
njzjz Nov 13, 2024
6d5b45a
debug
njzjz Nov 13, 2024
e569ed9
fix docker name
njzjz Nov 13, 2024
1b3fd5e
set allow_growth to True
njzjz Nov 13, 2024
9d95778
fix compile error
njzjz Nov 13, 2024
09efdd3
fix typo
njzjz Nov 13, 2024
39f357c
call TFE_ContextOptionsSetConfig
njzjz Nov 13, 2024
cfff834
fix config
njzjz Nov 13, 2024
ca02625
Revert "debug"
njzjz Nov 13, 2024
14 changes: 7 additions & 7 deletions .github/workflows/test_cc.yml
@@ -27,7 +27,13 @@ jobs:
mpi: mpich
- uses: lukka/get-cmake@latest
- run: python -m pip install uv
- run: source/install/uv_with_retry.sh pip install --system tensorflow
- name: Install Python dependencies
run: |
source/install/uv_with_retry.sh pip install --system tensorflow-cpu
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip -O libtorch.zip
@@ -47,12 +53,6 @@ jobs:
CMAKE_GENERATOR: Ninja
CXXFLAGS: ${{ matrix.check_memleak && '-fsanitize=leak' || '' }}
# test lammps
- run: |
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp] mpi4py
env:
DP_BUILD_TESTING: 1
if: ${{ !matrix.check_memleak }}
- run: pytest --cov=deepmd source/lmp/tests
env:
OMP_NUM_THREADS: 1
8 changes: 6 additions & 2 deletions .github/workflows/test_cuda.yml
@@ -19,7 +19,7 @@ jobs:
runs-on: nvidia
# https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
options: --gpus all
if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
steps:
@@ -63,12 +63,16 @@
CUDA_VISIBLE_DEVICES: 0
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
XLA_PYTHON_CLIENT_PREALLOCATE: false
if: false # debug
- name: Convert models
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
unzip libtorch.zip
- run: |
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
source/install/test_cc_local.sh
env:
OMP_NUM_THREADS: 1
@@ -79,7 +83,7 @@
DP_VARIANT: cuda
DP_USE_MPICH2: 1
- run: |
export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$CUDA_PATH/lib64:/usr/lib/x86_64-linux-gnu/:$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$LD_LIBRARY_PATH
export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
python -m pytest source/ipi/tests
4 changes: 3 additions & 1 deletion doc/backend.md
@@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
The model is device-specific: a model generated on a GPU device cannot be run on CPUs.
Currently, this backend is under active development and does not yet support training.

### DP {{ dpmodel_icon }}

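
As an aside on the `.savedmodel` C++ inference path documented above: it is driven through the existing `deepmd.hpp` header-only API. The sketch below is an editorial illustration only, not part of the diff; the `DeepPot` constructor and the `compute(energy, force, virial, coord, atype, box)` overload are assumed to match the current `deepmd.hpp` header, and the file name `frozen_model.savedmodel` follows the LAMMPS example added later in this PR.

```cpp
// Hypothetical sketch: single-point inference with a JAX model exported to the
// TensorFlow SavedModel format, via the header-only C++ API (deepmd.hpp).
// Signatures are assumed from the existing header and may differ by release.
#include <iostream>
#include <vector>

#include "deepmd.hpp"

int main() {
  // The SavedModel is device-specific: it must be generated on the same kind
  // of device (CPU vs. GPU) that it will run on.
  deepmd::hpp::DeepPot dp("frozen_model.savedmodel");

  // Two atoms of a water-like system: 3*natoms coordinates and a flattened
  // 3x3 cell, both in Å.
  std::vector<double> coord = {0.0, 0.0, 0.0, 0.0, 0.0, 0.96};
  std::vector<int> atype = {0, 1};  // indices into the model's type map
  std::vector<double> cell = {10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0};

  double energy = 0.0;
  std::vector<double> force, virial;
  dp.compute(energy, force, virial, coord, atype, cell);

  std::cout << "energy = " << energy << std::endl;
  return 0;
}
```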
10 changes: 6 additions & 4 deletions doc/install/install-from-source.md
@@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.

Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.

@@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake

@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} Whether building the TensorFlow backend.
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.

:::

@@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `PATH`

{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.

:::

10 changes: 10 additions & 0 deletions doc/model/dpa2.md
@@ -18,6 +18,16 @@ If one runs LAMMPS with MPI, the customized OP library for the C++ interface sho
If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
otherwise the communication between GPU cards falls back to the slower CPU implementation.

## Limitations of the JAX backend with LAMMPS {{ jax_icon }}

When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.

```lammps
atom_modify map yes
```

See the example `examples/water/lmp/jax_dpa2.lammps`.

## Data format

DPA-2 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).
Binary file added examples/water/dpa2/frozen_model.pth
31 changes: 31 additions & 0 deletions examples/water/lmp/jax_dpa2.lammps
@@ -0,0 +1,31 @@


# bulk water

units metal
boundary p p p
atom_style atomic
# Below line is required when using DPA-2 with the JAX backend
atom_modify map yes

neighbor 2.0 bin
neigh_modify every 10 delay 0 check no

read_data water.lmp
mass 1 16
mass 2 2

# See https://deepmd.rtfd.io/lammps/ for usage
pair_style deepmd frozen_model.savedmodel
# If atom names (O H in this example) are not set in the pair_coeff command, the type_map defined by the training parameter will be used by default.
pair_coeff * * O H

velocity all create 330.0 23456789

fix 1 all nvt temp 330.0 330.0 0.5
timestep 0.0005
thermo_style custom step pe ke etotal temp press vol
thermo 100
dump 1 all custom 100 water.dump id type x y z

run 1000
16 changes: 13 additions & 3 deletions source/api_c/include/c_api.h
@@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 24
#define DP_C_API_VERSION 25

/**
* @brief Neighbor list.
@@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/*
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
@@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* recvproc,
void* world);

/*
/**
* @brief Set mask for a neighbor list.
*
* @param nl Neighbor list.
@@ -78,6 +78,16 @@
**/
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);

/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 25
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);

/**
* @brief Delete a neighbor list.
*
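
For context on the new `DP_NlistSetMapping` entry point (API version 25), here is a hedged usage sketch built only from the declarations shown above. The neighbor-list contents are fabricated for illustration; in LAMMPS they would come from the pair style's neighbor build rather than being hand-written.

```cpp
// Illustrative only: build a tiny neighbor list through the C API and attach
// the atom mapping used by the JAX/SavedModel backend. All values are made
// up; only the function signatures come from c_api.h above.
#include "c_api.h"

void example_nlist_with_mapping() {
  // Two local atoms; each has the other as its single neighbor.
  int ilist[2] = {0, 1};
  int neigh0[1] = {1};
  int neigh1[1] = {0};
  int numneigh[2] = {1, 1};
  int* firstneigh[2] = {neigh0, neigh1};

  DP_Nlist* nl = DP_NewNlist(2, ilist, numneigh, firstneigh);

  // Mapping from all atoms (local + ghost) back to real atoms, of size nall.
  // With no ghost atoms present, the mapping is the identity.
  int mapping[2] = {0, 1};
  DP_NlistSetMapping(nl, mapping);

  DP_DeleteNlist(nl);
}
```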
5 changes: 5 additions & 0 deletions source/api_c/include/deepmd.hpp
@@ -863,6 +863,11 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
/**
* @brief Set mapping for this neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
*/
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
};

/**
Expand Down
3 changes: 3 additions & 0 deletions source/api_c/src/c_api.cc
@@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
return new_nl;
}
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

// DP Base Model