Support different precisions; support the environment matrix
Han Wang committed Jan 8, 2024
1 parent f181a30 commit c7840de
Showing 5 changed files with 171 additions and 14 deletions.
8 changes: 8 additions & 0 deletions deepmd_utils/model_format/__init__.py
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .common import (
PRECISION_DICT,
)
from .network import (
EmbeddingNet,
NativeLayer,
@@ -7,6 +10,10 @@
save_dp_model,
traverse_model_dict,
)
from .env_mat import (
EnvMat,
)


__all__ = [
"EmbeddingNet",
@@ -15,4 +22,5 @@
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
"PRECISION_DICT",
]
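
With this change, EnvMat and PRECISION_DICT become part of the package's public interface. A minimal import sketch (illustrative, not part of the commit):

    import numpy as np
    from deepmd_utils.model_format import EnvMat, PRECISION_DICT

    # PRECISION_DICT maps a precision name to the numpy dtype used for parameters
    assert PRECISION_DICT["float32"] is np.float32
    assert PRECISION_DICT["default"] is np.float64
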
16 changes: 16 additions & 0 deletions deepmd_utils/model_format/common.py
@@ -0,0 +1,16 @@
from abc import ABC
import numpy as np

PRECISION_DICT = {
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"default": np.float64,
}

class NativeOP(ABC):
"""The unit operation of a native model."""

def call(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
raise NotImplementedError

Codecov warning: added line 16 in deepmd_utils/model_format/common.py was not covered by tests.
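
NativeOP is the minimal interface that the concrete operations (NativeLayer, EnvMat, and the networks built from them) implement by overriding call(). A small illustrative subclass, assuming nothing beyond what common.py defines; the Scale op below is hypothetical:

    import numpy as np
    from deepmd_utils.model_format.common import PRECISION_DICT, NativeOP

    class Scale(NativeOP):
        """Hypothetical op: multiply the input by a constant of a given precision."""

        def __init__(self, factor: float, precision: str = "default"):
            # cast the constant to the numpy dtype named by `precision`
            self.factor = PRECISION_DICT[precision](factor)

        def call(self, x: np.ndarray) -> np.ndarray:
            return x * self.factor

    y = Scale(2.0, precision="float32").call(np.ones(3))
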
85 changes: 85 additions & 0 deletions deepmd_utils/model_format/env_mat.py
@@ -0,0 +1,85 @@
import numpy as np

from .common import (
NativeOP,
)

def compute_smooth_weight(
distance: np.ndarray,
rmin: float,
rmax: float,
):
"""Compute smooth weight for descriptor elements."""
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = np.logical_not(np.logical_or(min_mask, max_mask))
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6. * uu * uu + 15. * uu - 10.) + 1.
return vv * mid_mask + min_mask

def _make_env_mat(
nlist,
coord,
rcut: float,
rcut_smth: float,
):
"""Make smooth environment matrix."""
nf, nloc, nnei = nlist.shape
# nf x nall x 3
coord = coord.reshape(nf, -1, 3)
mask = (nlist >= 0)
nlist = (nlist * mask)
# nf x (nloc x nnei) x 3
index = np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3))
coord_r = np.take_along_axis(coord, index, 1)
# nf x nloc x nnei x 3
coord_r = coord_r.reshape(nf, nloc, nnei, 3)
# nf x nloc x 1 x 3
coord_l = coord[:,:nloc].reshape(nf, -1, 1, 3)
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
length = np.linalg.norm(diff, axis=-1, keepdims=True)
# avoid division by zero: padded (masked) neighbor slots get length 1
length = length + ~np.expand_dims(mask, -1)
t0 = 1 / length
t1 = diff / length ** 2
weight = compute_smooth_weight(length, rcut_smth, rcut)
env_mat_se_a = (
np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
)
return env_mat_se_a, diff * np.expand_dims(mask,-1), weight


class EnvMat(NativeOP):
"""The smooth environment matrix op."""

def __init__(
self,
rcut,
rcut_smth,
):
self.rcut = rcut
self.rcut_smth = rcut_smth

def call(
self,
nlist,
coord_ext,
):
em, diff, ww = _make_env_mat(
nlist, coord_ext, self.rcut, self.rcut_smth)
return em, ww

def serialize(
self,
) -> dict:
return {
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
}

@classmethod
def deserialize(
cls,
data: dict,
) -> "EnvMat":
return cls(**data)
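
A short usage sketch of the new EnvMat op, following the array layout used in the test below (nlist of shape nf x nloc x nnei with -1 padding, extended coordinates flattened to nf x (nall * 3)); the values here are illustrative:

    import numpy as np
    from deepmd_utils.model_format import EnvMat

    em = EnvMat(rcut=2.2, rcut_smth=0.4)
    # one frame, one local atom, two neighbor slots (-1 marks an empty slot)
    nlist = np.array([[[1, -1]]])
    # two atoms in the extended region, flattened xyz coordinates
    coord_ext = np.array([[0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
    mat, ww = em.call(nlist, coord_ext)
    # mat has shape (1, 1, 2, 4): the switched [1/r, x/r^2, y/r^2, z/r^2] per neighbor
    # serialize/deserialize round-trips the cutoffs, as checked by the test below
    em2 = EnvMat.deserialize(em.serialize())
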
31 changes: 19 additions & 12 deletions deepmd_utils/model_format/network.py
@@ -4,9 +4,6 @@
See issue #2982 for more information.
"""
import json
from abc import (
ABC,
)
from typing import (
List,
Optional,
@@ -20,6 +17,11 @@
except ImportError:
__version__ = "unknown"

from .common import (
NativeOP,
PRECISION_DICT,
)


def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False):
"""Traverse a model dict and call callback on each variable.
@@ -124,12 +126,6 @@ def load_dp_model(filename: str) -> dict:
return model_dict


class NativeOP(ABC):
"""The unit operation of a native model."""

def call(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
raise NotImplementedError


class NativeLayer(NativeOP):
@@ -156,10 +152,13 @@ def __init__(
idt: Optional[np.ndarray] = None,
activation_function: Optional[str] = None,
resnet: bool = False,
precision: str = "default",
) -> None:
self.w = w
self.b = b
self.idt = idt
prec = PRECISION_DICT[precision.lower()]
self.precision = precision
self.w = w.astype(prec) if w is not None else None
self.b = b.astype(prec) if b is not None else None
self.idt = idt.astype(prec) if idt is not None else None
self.activation_function = activation_function
self.resnet = resnet

@@ -180,6 +179,7 @@ def serialize(self) -> dict:
return {
"activation_function": self.activation_function,
"resnet": self.resnet,
"precision": self.precision,
"@variables": data,
}

@@ -198,6 +198,7 @@ def deserialize(cls, data: dict) -> "NativeLayer":
idt=data["@variables"].get("idt", None),
activation_function=data["activation_function"],
resnet=data.get("resnet", False),
precision=data.get("precision", "default"),
)

def __setitem__(self, key, value):
@@ -211,6 +212,8 @@ def __setitem__(self, key, value):
self.activation_function = value
elif key == "resnet":
self.resnet = value
elif key == "precision":
self.precision = value

Codecov warning: added lines 215-216 in deepmd_utils/model_format/network.py were not covered by tests.
else:
raise KeyError(key)

@@ -225,6 +228,8 @@ def __getitem__(self, key):
return self.activation_function
elif key == "resnet":
return self.resnet
elif key == "precision":
return self.precision

Codecov warning: added lines 231-232 in deepmd_utils/model_format/network.py were not covered by tests.
else:
raise KeyError(key)

@@ -338,6 +343,7 @@ def __init__(
neuron: List[int] = [24, 48, 96],
activation_function: str = "tanh",
resnet_dt: bool = False,
precision: str = "default",
):
layers = []
i_in = in_dim
@@ -351,6 +357,7 @@ def __init__(
idt=rng.normal(size=(ii)) if resnet_dt else None,
activation_function=activation_function,
resnet=True,
precision=precision,
).serialize()
)
i_in = i_ot
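
With the new precision argument, NativeLayer casts its parameters to the requested numpy dtype at construction time, and the precision name is carried through serialize/deserialize. A minimal round-trip sketch, assuming only the constructor arguments and serialization keys shown above:

    import numpy as np
    from deepmd_utils.model_format import NativeLayer

    w = np.ones((5, 10))
    b = np.zeros(10)
    layer = NativeLayer(w, b, activation_function="tanh", precision="float32")
    assert layer.w.dtype == np.float32

    restored = NativeLayer.deserialize(layer.serialize())
    assert restored.serialize()["precision"] == "float32"
    assert restored["precision"] == "float32"  # __getitem__ also exposes it
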
45 changes: 43 additions & 2 deletions source/tests/test_model_format_utils.py
@@ -9,6 +9,7 @@
import numpy as np

from deepmd_utils.model_format import (
EmbeddingNet,
EnvMat,
NativeLayer,
NativeNet,
@@ -19,18 +20,19 @@

class TestNativeLayer(unittest.TestCase):
def test_serialize_deserize(self):
for (ni, no), bias, ut, activation_function, resnet, ashp in itertools.product(
for (ni, no), bias, ut, activation_function, resnet, ashp, prec in itertools.product(
[(5, 5), (5, 10), (5, 9), (9, 5)],
[True, False],
[True, False],
["tanh", "none"],
[True, False],
[None, [4], [3, 2]],
["float32", "float64", "default"],
):
ww = np.full((ni, no), 3.0)
bb = np.full((no,), 4.0) if bias else None
idt = np.full((no,), 5.0) if ut else None
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet)
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet, prec)
nl1 = NativeLayer.deserialize(nl0.serialize())
inp_shap = [ww.shape[0]]
if ashp is not None:
@@ -134,3 +136,42 @@ def test_save_load_model(self):
def tearDown(self) -> None:
if os.path.exists(self.filename):
os.remove(self.filename)


class TestEnvMat(unittest.TestCase):
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall * 3])
self.atype_ext = np.array(
[0, 0, 1, 0], dtype=int
).reshape([1, self.nall])
# sel = [5, 2]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, 0, -1],
],
dtype=int,
).reshape([1, self.nloc, 7])
self.rcut = 2.2
self.rcut_smth = 0.4

def test_self_consistency(
self,
):
em0 = EnvMat(self.rcut, self.rcut_smth)
em1 = EnvMat.deserialize(em0.serialize())
mm0, ww0 = em0.call(self.nlist, self.coord_ext)
mm1, ww1 = em1.call(self.nlist, self.coord_ext)
np.testing.assert_allclose(mm0, mm1)
np.testing.assert_allclose(ww0, ww1)
