Skip to content

Commit

Permalink
torch support for the dp model format
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 31, 2024
1 parent fc3cc8e commit 515a724
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 3 deletions.
9 changes: 7 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def forward_common(
"""
nframes, nloc = atype.shape[:2]
if box is not None:
coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3))
coord_normalized = normalize_coord(
coord.view(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
)
else:
coord_normalized = coord.clone()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
Expand All @@ -74,7 +77,7 @@ def forward_common(
self.get_sel(),
distinguish_types=self.distinguish_types(),
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = extended_coord.view(nframes, -1, 3)
model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
Expand Down Expand Up @@ -119,6 +122,8 @@ def forward_common_lower(
the result dict, defined by the fitting net output def.
"""
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.view(nframes, -1, 3)
atomic_ret = self.forward_atomic(
extended_coord,
extended_atype,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
dtype = env.GLOBAL_PT_FLOAT_PRECISION


class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist):
class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

Expand Down
210 changes: 210 additions & 0 deletions source/tests/pt/test_dp_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np
import torch

from deepmd.model_format import DescrptSeA as DPDescrptSeA
from deepmd.model_format import DPModel as DPDPModel
from deepmd.model_format import InvarFitting as DPInvarFitting
from deepmd.pt.model.descriptor.se_a import (
DescrptSeA,
)
from deepmd.pt.model.model.ener import (
DPModel,
)
from deepmd.pt.model.task.ener import (
InvarFitting,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)

from .test_env_mat import (
TestCaseSingleFrameWithNlist,
TestCaseSingleFrameWithoutNlist,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION


class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
def setUp(self):
TestCaseSingleFrameWithoutNlist.setUp(self)

def test_self_consistency(self):
nf, nloc = self.atype.shape

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nf is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nloc is not used.
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.forward_common(*args)
ret1 = md1.forward_common(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret1["energy_derv_c_redu"]),
)
ret0 = md0.forward_common(*args, do_atomic_virial=True)
ret1 = md1.forward_common(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
)

def test_dp_consistency(self):
nf, nloc = self.atype.shape

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nf is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nloc is not used.
ds = DPDescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
)
ft = DPInvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
)
type_map = ["foo", "bar"]
md0 = DPDPModel(ds, ft, type_map=type_map)
md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
args0 = [self.coord, self.atype, self.cell]
args1 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.call(*args0)
ret1 = md1.forward_common(*args1)
np.testing.assert_allclose(
ret0["energy"],
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
ret0["energy_redu"],
to_numpy_array(ret1["energy_redu"]),
)


class TestDPModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(self):
nf, nloc, nnei = self.nlist.shape

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nf is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nloc is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nnei is not used.
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
args = [
to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
]
ret0 = md0.forward_common_lower(*args)
ret1 = md1.forward_common_lower(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
)
ret0 = md0.forward_common_lower(*args, do_atomic_virial=True)
ret1 = md1.forward_common_lower(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
)

def test_dp_consistency(self):
rng = np.random.default_rng()

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable rng is not used.
nf, nloc, nnei = self.nlist.shape

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nf is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nloc is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nnei is not used.
ds = DPDescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
)
ft = DPInvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
)
type_map = ["foo", "bar"]
md0 = DPDPModel(ds, ft, type_map=type_map)
md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
args0 = [self.coord_ext, self.atype_ext, self.nlist]
args1 = [
to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
]
ret0 = md0.call_lower(*args0)
ret1 = md1.forward_common_lower(*args1)
np.testing.assert_allclose(
ret0["energy"],
to_numpy_array(ret1["energy"]),
)
np.testing.assert_allclose(
ret0["energy_redu"],
to_numpy_array(ret1["energy_redu"]),
)

def test_jit(self):
nf, nloc, nnei = self.nlist.shape

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nf is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nloc is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nnei is not used.
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
distinguish_types=ds.distinguish_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
torch.jit.script(md0)
21 changes: 21 additions & 0 deletions source/tests/pt/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ def setUp(self):
self.rcut_smth = 2.2


class TestCaseSingleFrameWithoutNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nf, self.nt = 1, 2
self.coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([1, self.nloc * 3])
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [5, 2]
self.rcut = 0.4
self.rcut_smth = 2.2


# to be merged with the tf test case
@unittest.skipIf(not support_env_mat, "EnvMat not supported")
class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist):
Expand Down

0 comments on commit 515a724

Please sign in to comment.