Skip to content

Commit

Permalink
feat(pt): support fparam/aparam in DeepEval (#3356)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Feb 28, 2024
1 parent 3ad57da commit 2a1508d
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 48 deletions.
27 changes: 27 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,33 @@ def _standard_input(self, coords, cells, atom_types, fparam, aparam, mixed_type)
aparam = np.array(aparam)
natoms, nframes = self._get_natoms_and_nframes(coords, atom_types, mixed_type)
atom_types = self._expande_atype(atom_types, nframes, mixed_type)
coords = coords.reshape(nframes, natoms, 3)
if cells is not None:
cells = cells.reshape(nframes, 3, 3)
if fparam is not None:
fdim = self.get_dim_fparam()
if fparam.size == nframes * fdim:
fparam = np.reshape(fparam, [nframes, fdim])
elif fparam.size == fdim:
fparam = np.tile(fparam.reshape([-1]), [nframes, 1])
else:
raise RuntimeError(
"got wrong size of frame param, should be either %d x %d or %d"
% (nframes, fdim, fdim)
)
if aparam is not None:
fdim = self.get_dim_aparam()
if aparam.size == nframes * natoms * fdim:
aparam = np.reshape(aparam, [nframes, natoms * fdim])
elif aparam.size == natoms * fdim:
aparam = np.tile(aparam.reshape([-1]), [nframes, 1])
elif aparam.size == fdim:
aparam = np.tile(aparam.reshape([-1]), [nframes, natoms])
else:
raise RuntimeError(
"got wrong size of frame param, should be either %d x %d x %d or %d x %d or %d"
% (nframes, natoms, fdim, natoms, fdim, fdim)
)
return coords, cells, atom_types, fparam, aparam, nframes, natoms

def get_sel_type(self) -> List[int]:
Expand Down
32 changes: 27 additions & 5 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
DEVICE,
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)

if TYPE_CHECKING:
import ase.neighborlist
Expand Down Expand Up @@ -228,8 +231,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -240,7 +241,12 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords,
cells,
atom_types,
fparam,
aparam,
request_defs,
)
return dict(
zip(
Expand Down Expand Up @@ -330,6 +336,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: List[OutputVariableDef],
):
model = self.dp.to(DEVICE)
Expand All @@ -355,12 +363,26 @@ def _eval_model(
)
else:
box_input = None

if fparam is not None:
fparam_input = to_torch_tensor(fparam.reshape(-1, self.get_dim_fparam()))
else:
fparam_input = None
if aparam is not None:
aparam_input = to_torch_tensor(
aparam.reshape(-1, natoms, self.get_dim_aparam())
)
else:
aparam_input = None
do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
)
batch_output = model(
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
coord_input,
type_input,
box=box_input,
do_atomic_virial=do_atomic_virial,
fparam=fparam_input,
aparam=aparam_input,
)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
Expand Down
9 changes: 8 additions & 1 deletion deepmd/pt/train/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def forward(
task_key: Optional[torch.Tensor] = None,
inference_only=False,
do_atomic_virial=False,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
if not self.multi_task:
task_key = "Default"
Expand All @@ -172,7 +174,12 @@ def forward(
task_key is not None
), f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}."
model_pred = self.model[task_key](
coord, atype, box=box, do_atomic_virial=do_atomic_virial
coord,
atype,
box=box,
do_atomic_virial=do_atomic_virial,
fparam=fparam,
aparam=aparam,
)
natoms = atype.shape[-1]
if not self.inference_only and not inference_only:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/infer/fparam_aparam.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ node {
dtype: DT_STRING
tensor_shape {
}
string_val: "{\"model\":{\"data_stat_nbatch\":1,\"descriptor\":{\"type\":\"se_e2_a\",\"sel\":[60],\"rcut_smth\":1.8,\"rcut\":6.0,\"neuron\":[5,10,20],\"resnet_dt\":false,\"axis_neuron\":8,\"seed\":1,\"activation_function\":\"tanh\",\"type_one_side\":false,\"precision\":\"default\",\"trainable\":true,\"exclude_types\":[],\"set_davg_zero\":false},\"fitting_net\":{\"neuron\":[5,5,5],\"resnet_dt\":true,\"numb_fparam\":1,\"numb_aparam\":1,\"seed\":1,\"type\":\"ener\",\"activation_function\":\"tanh\",\"precision\":\"default\",\"trainable\":true,\"rcond\":0.001,\"atom_ener\":[],\"use_aparam_as_mask\":false},\"data_stat_protect\":0.01,\"data_bias_nsample\":10},\"loss\":{\"start_pref_e\":0.02,\"limit_pref_e\":1,\"start_pref_f\":1000,\"limit_pref_f\":1,\"start_pref_v\":0,\"limit_pref_v\":0,\"type\":\"ener\",\"start_pref_ae\":0.0,\"limit_pref_ae\":0.0,\"start_pref_pf\":0.0,\"limit_pref_pf\":0.0,\"enable_atom_ener_coeff\":false},\"learning_rate\":{\"start_lr\":0.001,\"stop_lr\":3e-08,\"decay_steps\":5000,\"scale_by_worker\":\"linear\",\"type\":\"exp\"},\"training\":{\"training_data\":{\"systems\":[\"../data/e3000_i2000/\",\"../data/e8000_i2000/\"],\"set_prefix\":\"set\",\"batch_size\":1,\"auto_prob\":\"prob_sys_size\",\"sys_probs\":null},\"seed\":1,\"disp_file\":\"lcurve.out\",\"disp_freq\":100,\"save_freq\":1000,\"save_ckpt\":\"model.ckpt\",\"disp_training\":true,\"time_training\":true,\"profiling\":false,\"profiling_file\":\"timeline.json\",\"numb_steps\":1000,\"validation_data\":null,\"enable_profiler\":false,\"tensorboard\":false,\"tensorboard_log_dir\":\"log\",\"tensorboard_freq\":1}}"
string_val: "{\"model\":{\"data_stat_nbatch\":1,\"type_map\":[\"O\"],\"descriptor\":{\"type\":\"se_e2_a\",\"sel\":[60],\"rcut_smth\":1.8,\"rcut\":6.0,\"neuron\":[5,10,20],\"resnet_dt\":false,\"axis_neuron\":8,\"seed\":1,\"activation_function\":\"tanh\",\"type_one_side\":false,\"precision\":\"default\",\"trainable\":true,\"exclude_types\":[],\"set_davg_zero\":false},\"fitting_net\":{\"neuron\":[5,5,5],\"resnet_dt\":true,\"numb_fparam\":1,\"numb_aparam\":1,\"seed\":1,\"type\":\"ener\",\"activation_function\":\"tanh\",\"precision\":\"default\",\"trainable\":true,\"rcond\":0.001,\"atom_ener\":[],\"use_aparam_as_mask\":false},\"data_stat_protect\":0.01,\"data_bias_nsample\":10},\"loss\":{\"start_pref_e\":0.02,\"limit_pref_e\":1,\"start_pref_f\":1000,\"limit_pref_f\":1,\"start_pref_v\":0,\"limit_pref_v\":0,\"type\":\"ener\",\"start_pref_ae\":0.0,\"limit_pref_ae\":0.0,\"start_pref_pf\":0.0,\"limit_pref_pf\":0.0,\"enable_atom_ener_coeff\":false},\"learning_rate\":{\"start_lr\":0.001,\"stop_lr\":3e-08,\"decay_steps\":5000,\"scale_by_worker\":\"linear\",\"type\":\"exp\"},\"training\":{\"training_data\":{\"systems\":[\"../data/e3000_i2000/\",\"../data/e8000_i2000/\"],\"set_prefix\":\"set\",\"batch_size\":1,\"auto_prob\":\"prob_sys_size\",\"sys_probs\":null},\"seed\":1,\"disp_file\":\"lcurve.out\",\"disp_freq\":100,\"save_freq\":1000,\"save_ckpt\":\"model.ckpt\",\"disp_training\":true,\"time_training\":true,\"profiling\":false,\"profiling_file\":\"timeline.json\",\"numb_steps\":1000,\"validation_data\":null,\"enable_profiler\":false,\"tensorboard\":false,\"tensorboard_log_dir\":\"log\",\"tensorboard_freq\":1}}"
}
}
}
Expand Down
Binary file added source/tests/infer/fparam_aparam.pth
Binary file not shown.
22 changes: 22 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
DeepPot,
)

from ...tf.test_deeppot_a import (
FparamAparamCommonTest,
)


class TestDeepPot(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -123,3 +127,21 @@ def setUp(self):
@unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu"))
def test_dp_test_cpu(self):
self.test_dp_test()


class TestFparamAparamPT(FparamAparamCommonTest, unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dp = DeepPot(
str(Path(__file__).parent.parent.parent / "infer/fparam_aparam.pth")
)

def setUp(self):
super().setUp()
# For unclear reason, the precision is only 1e-7
# not sure if it is expected...
self.places = 1e-7

@classmethod
def tearDownClass(cls):
pass
77 changes: 36 additions & 41 deletions source/tests/tf/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,17 +894,9 @@ def test_eval_typeebd(self):
np.testing.assert_almost_equal(eval_typeebd, expected_typeebd, default_places)


class TestFparamAparam(unittest.TestCase):
class FparamAparamCommonTest:
"""Test fparam and aparam."""

@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(
str(infer_path / os.path.join("fparam_aparam.pbtxt")),
"fparam_aparam.pb",
)
cls.dp = DeepPot("fparam_aparam.pb")

def setUp(self):
self.coords = np.array(
[
Expand Down Expand Up @@ -1022,15 +1014,11 @@ def setUp(self):
2.875323131744185121e-02,
]
)

@classmethod
def tearDownClass(cls):
os.remove("fparam_aparam.pb")
cls.dp = None
self.places = default_places

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 1)
self.assertAlmostEqual(self.dp.get_rcut(), 6.0, places=default_places)
self.assertAlmostEqual(self.dp.get_rcut(), 6.0, places=self.places)
self.assertEqual(self.dp.get_dim_fparam(), 1)
self.assertEqual(self.dp.get_dim_aparam(), 1)

Expand All @@ -1050,13 +1038,11 @@ def test_1frame(self):
self.assertEqual(ff.shape, (nframes, natoms, 3))
self.assertEqual(vv.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(
ff.ravel(), self.expected_f.ravel(), default_places
)
np.testing.assert_almost_equal(ff.ravel(), self.expected_f.ravel(), self.places)
expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_1frame_atm(self):
ee, ff, vv, ae, av = self.dp.eval(
Expand All @@ -1076,19 +1062,13 @@ def test_1frame_atm(self):
self.assertEqual(ae.shape, (nframes, natoms, 1))
self.assertEqual(av.shape, (nframes, natoms, 9))
# check values
np.testing.assert_almost_equal(
ff.ravel(), self.expected_f.ravel(), default_places
)
np.testing.assert_almost_equal(
ae.ravel(), self.expected_e.ravel(), default_places
)
np.testing.assert_almost_equal(
av.ravel(), self.expected_v.ravel(), default_places
)
np.testing.assert_almost_equal(ff.ravel(), self.expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), self.expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), self.expected_v.ravel(), self.places)
expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_2frame_atm_single_param(self):
coords2 = np.concatenate((self.coords, self.coords))
Expand All @@ -1113,13 +1093,13 @@ def test_2frame_atm_single_param(self):
expected_f = np.concatenate((self.expected_f, self.expected_f), axis=0)
expected_e = np.concatenate((self.expected_e, self.expected_e), axis=0)
expected_v = np.concatenate((self.expected_v, self.expected_v), axis=0)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), default_places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), default_places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), default_places)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), self.places)
expected_se = np.sum(expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_2frame_atm_all_param(self):
coords2 = np.concatenate((self.coords, self.coords))
Expand All @@ -1144,13 +1124,28 @@ def test_2frame_atm_all_param(self):
expected_f = np.concatenate((self.expected_f, self.expected_f), axis=0)
expected_e = np.concatenate((self.expected_e, self.expected_e), axis=0)
expected_v = np.concatenate((self.expected_v, self.expected_v), axis=0)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), default_places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), default_places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), default_places)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), self.places)
expected_se = np.sum(expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)


class TestFparamAparam(FparamAparamCommonTest, unittest.TestCase):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(
str(infer_path / os.path.join("fparam_aparam.pbtxt")),
"fparam_aparam.pb",
)
cls.dp = DeepPot("fparam_aparam.pb")

@classmethod
def tearDownClass(cls):
os.remove("fparam_aparam.pb")
cls.dp = None


class TestDeepPotAPBCNeighborList(TestDeepPotAPBC):
Expand Down

0 comments on commit 2a1508d

Please sign in to comment.