Skip to content

Commit

Permalink
Merge branch 'devel' into jax-zbl
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 3, 2024
2 parents 84803b8 + bfbe2ed commit 70e9eae
Show file tree
Hide file tree
Showing 58 changed files with 3,931 additions and 592 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
if: false # skip as we use nvidia image
- run: python -m pip install -U uv
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0"
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]"
- run: |
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
Expand All @@ -61,6 +61,8 @@ jobs:
env:
NUM_WORKERS: 0
CUDA_VISIBLE_DEVICES: 0
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
XLA_PYTHON_CLIENT_PREALLOCATE: false
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
Expand Down
31 changes: 31 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,37 @@ def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable compression by delegating to the descriptor.

    All arguments are forwarded positionally, in the same order, to
    ``self.descriptor.enable_compression``.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency
    """
    compression_args = (
        min_nbor_dist,
        table_extrapolate,
        table_stride_1,
        table_stride_2,
        check_frequency,
    )
    self.descriptor.enable_compression(*compression_args)

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
32 changes: 32 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,38 @@ def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
)
return [p[0] for p in zipped], [p[1] for p in zipped]

def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable compression on every sub-model.

    The same arguments are forwarded positionally to the
    ``enable_compression`` method of each entry in ``self.models``,
    in iteration order.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency
    """
    compression_args = (
        min_nbor_dist,
        table_extrapolate,
        table_stride_1,
        table_stride_2,
        check_frequency,
    )
    for sub_model in self.models:
        sub_model.enable_compression(*compression_args)

def forward_atomic(
self,
extended_coord,
Expand Down
25 changes: 25 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,31 @@ def change_type_map(
) -> None:
pass

def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable compression; this default implementation always raises.

    NOTE(review): presumably overridden by atomic models that support
    compression -- confirm against the concrete subclasses.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency

    Raises
    ------
    NotImplementedError
        Always; none of the arguments are inspected.
    """
    # Message typo fixed: "atomi model" -> "atomic model".
    raise NotImplementedError("This atomic model doesn't support compression!")

def make_atom_mask(
self,
atype: t_tensor,
Expand Down
37 changes: 25 additions & 12 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EnvMat,
NetworkCollection,
Expand Down Expand Up @@ -787,9 +794,10 @@ def call(
The smooth switch function. shape: nf x nloc x nnei
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
Expand All @@ -798,7 +806,10 @@ def call(
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
g1_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
g1, _, _, _, _ = self.repinit(
nlist_dict[
Expand All @@ -823,16 +834,18 @@ def call(
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
g1 = xp.concat([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
g1_ext = np.take_along_axis(g1, mapping_ext, axis=1)
mapping_ext = xp.tile(
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
)
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -846,7 +859,7 @@ def call(
mapping,
)
if self.concat_output_tebd:
g1 = np.concatenate([g1, g1_inp], axis=-1)
g1 = xp.concat([g1, g1_inp], axis=-1)
return g1, rot_mat, g2, h2, sw

def serialize(self) -> dict:
Expand Down Expand Up @@ -883,8 +896,8 @@ def serialize(self) -> dict:
"embeddings": repinit.embeddings.serialize(),
"env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(),
"@variables": {
"davg": repinit["davg"],
"dstd": repinit["dstd"],
"davg": to_numpy_array(repinit["davg"]),
"dstd": to_numpy_array(repinit["dstd"]),
},
}
if repinit.tebd_input_mode in ["strip"]:
Expand All @@ -896,8 +909,8 @@ def serialize(self) -> dict:
"repformer_layers": [layer.serialize() for layer in repformers.layers],
"env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(),
"@variables": {
"davg": repformers["davg"],
"dstd": repformers["dstd"],
"davg": to_numpy_array(repformers["davg"]),
"dstd": to_numpy_array(repformers["dstd"]),
},
}
data.update(
Expand All @@ -913,8 +926,8 @@ def serialize(self) -> dict:
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
"davg": to_numpy_array(repinit_three_body["davg"]),
"dstd": to_numpy_array(repinit_three_body["dstd"]),
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
Expand Down
32 changes: 32 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,38 @@ def get_stat_mean_and_stddev(
stddev_list.append(stddev_item)
return mean_list, stddev_list

def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable compression on every descriptor in the hybrid.

    The same arguments are forwarded positionally to the
    ``enable_compression`` method of each entry in ``self.descrpt_list``,
    in iteration order.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency
    """
    compression_args = (
        min_nbor_dist,
        table_extrapolate,
        table_stride_1,
        table_stride_2,
        check_frequency,
    )
    for sub_descriptor in self.descrpt_list:
        sub_descriptor.enable_compression(*compression_args)

def call(
self,
coord_ext,
Expand Down
25 changes: 25 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,31 @@ def compute_input_stats(
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable compression; this default implementation always raises.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency

    Raises
    ------
    NotImplementedError
        Always; none of the arguments are inspected.
    """
    unsupported_msg = "This descriptor doesn't support compression!"
    raise NotImplementedError(unsupported_msg)

@abstractmethod
def fwd(
self,
Expand Down
Loading

0 comments on commit 70e9eae

Please sign in to comment.