From 515a72468f34795d6a7db56a469105ac24b919b8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 31 Jan 2024 14:49:36 +0800 Subject: [PATCH] torch support for the dp model format --- deepmd/pt/model/model/make_model.py | 9 +- source/tests/pt/test_dp_atomic_model.py | 2 +- source/tests/pt/test_dp_mode.py | 210 ++++++++++++++++++++++++ source/tests/pt/test_env_mat.py | 21 +++ 4 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 source/tests/pt/test_dp_mode.py diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 3ddd21fbb8..9e2613368b 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -60,7 +60,10 @@ def forward_common( """ nframes, nloc = atype.shape[:2] if box is not None: - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + coord_normalized = normalize_coord( + coord.view(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) else: coord_normalized = coord.clone() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( @@ -74,7 +77,7 @@ def forward_common( self.get_sel(), distinguish_types=self.distinguish_types(), ) - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = extended_coord.view(nframes, -1, 3) model_predict_lower = self.forward_common_lower( extended_coord, extended_atype, @@ -119,6 +122,8 @@ def forward_common_lower( the result dict, defined by the fitting net output def. """ + nframes, nall = extended_atype.shape[:2] + extended_coord = extended_coord.view(nframes, -1, 3) atomic_ret = self.forward_atomic( extended_coord, extended_atype, diff --git a/source/tests/pt/test_dp_atomic_model.py b/source/tests/pt/test_dp_atomic_model.py index 46aab471c4..a8b20f829d 100644 --- a/source/tests/pt/test_dp_atomic_model.py +++ b/source/tests/pt/test_dp_atomic_model.py @@ -31,7 +31,7 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): +class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) diff --git a/source/tests/pt/test_dp_mode.py b/source/tests/pt/test_dp_mode.py new file mode 100644 index 0000000000..519d6ece77 --- /dev/null +++ b/source/tests/pt/test_dp_mode.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.model_format import DescrptSeA as DPDescrptSeA +from deepmd.model_format import DPModel as DPDPModel +from deepmd.model_format import InvarFitting as DPInvarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.ener import ( + DPModel, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithoutNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist): + def setUp(self): + TestCaseSingleFrameWithoutNlist.setUp(self) + + def test_self_consistency(self): + nf, nloc = self.atype.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.forward_common(*args) + ret1 = md1.forward_common(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"]), + to_numpy_array(ret1["energy_redu"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"]), + to_numpy_array(ret1["energy_derv_r"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"]), + to_numpy_array(ret1["energy_derv_c_redu"]), + ) + ret0 = md0.forward_common(*args, do_atomic_virial=True) + ret1 = md1.forward_common(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"]), + to_numpy_array(ret1["energy_derv_c"]), + ) + + def test_dp_consistency(self): + nf, nloc = self.atype.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPDPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args0 = [self.coord, self.atype, self.cell] + args1 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.call(*args0) + ret1 = md1.forward_common(*args1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + ret0["energy_redu"], + to_numpy_array(ret1["energy_redu"]), + ) + + +class TestDPModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_lower(*args) + ret1 = md1.forward_common_lower(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"]), + to_numpy_array(ret1["energy_redu"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"]), + to_numpy_array(ret1["energy_derv_r"]), + ) + ret0 = md0.forward_common_lower(*args, do_atomic_virial=True) + ret1 = md1.forward_common_lower(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"]), + to_numpy_array(ret1["energy_derv_c"]), + ) + + def test_dp_consistency(self): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPDPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args0 = [self.coord_ext, self.atype_ext, self.nlist] + args1 = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.call_lower(*args0) + ret1 = md1.forward_common_lower(*args1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + ret0["energy_redu"], + to_numpy_array(ret1["energy_redu"]), + ) + + def test_jit(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + torch.jit.script(md0) diff --git a/source/tests/pt/test_env_mat.py b/source/tests/pt/test_env_mat.py index f4931e9ecc..73707a3099 100644 --- a/source/tests/pt/test_env_mat.py +++ b/source/tests/pt/test_env_mat.py @@ -55,6 +55,27 @@ def setUp(self): self.rcut_smth = 2.2 +class TestCaseSingleFrameWithoutNlist: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nf, self.nt = 1, 2 + self.coord = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype=np.float64, + ).reshape([1, self.nloc * 3]) + self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc]) + self.cell = 2.0 * np.eye(3).reshape([1, 9]) + # sel = [5, 2] + self.sel = [5, 2] + self.rcut = 0.4 + self.rcut_smth = 2.2 + + # to be merged with the tf test case @unittest.skipIf(not support_env_mat, "EnvMat not supported") class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist):