Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support fparam/aparam in DeepEval #3356

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,33 @@
aparam = np.array(aparam)
natoms, nframes = self._get_natoms_and_nframes(coords, atom_types, mixed_type)
atom_types = self._expande_atype(atom_types, nframes, mixed_type)
coords = coords.reshape(nframes, natoms, 3)
if cells is not None:
cells = cells.reshape(nframes, 3, 3)
if fparam is not None:
fdim = self.get_dim_fparam()
if fparam.size == nframes * fdim:
fparam = np.reshape(fparam, [nframes, fdim])
elif fparam.size == fdim:
fparam = np.tile(fparam.reshape([-1]), [nframes, 1])

Check warning on line 483 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L475-L483

Added lines #L475 - L483 were not covered by tests
else:
raise RuntimeError(

Check warning on line 485 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L485

Added line #L485 was not covered by tests
"got wrong size of frame param, should be either %d x %d or %d"
% (nframes, fdim, fdim)
)
if aparam is not None:
fdim = self.get_dim_aparam()
if aparam.size == nframes * natoms * fdim:
aparam = np.reshape(aparam, [nframes, natoms * fdim])
elif aparam.size == natoms * fdim:
aparam = np.tile(aparam.reshape([-1]), [nframes, 1])
elif aparam.size == fdim:
aparam = np.tile(aparam.reshape([-1]), [nframes, natoms])

Check warning on line 496 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L489-L496

Added lines #L489 - L496 were not covered by tests
else:
raise RuntimeError(

Check warning on line 498 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L498

Added line #L498 was not covered by tests
"got wrong size of frame param, should be either %d x %d x %d or %d x %d or %d"
% (nframes, natoms, fdim, natoms, fdim, fdim)
)
return coords, cells, atom_types, fparam, aparam, nframes, natoms

def get_sel_type(self) -> List[int]:
Expand Down
32 changes: 27 additions & 5 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
DEVICE,
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.pt.utils.utils import (

Check warning on line 57 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L57

Added line #L57 was not covered by tests
to_torch_tensor,
)

if TYPE_CHECKING:
import ase.neighborlist
Expand Down Expand Up @@ -228,8 +231,6 @@
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -240,7 +241,12 @@
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords,
cells,
atom_types,
fparam,
aparam,
request_defs,
)
return dict(
zip(
Expand Down Expand Up @@ -330,6 +336,8 @@
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: List[OutputVariableDef],
):
model = self.dp.to(DEVICE)
Expand All @@ -355,12 +363,26 @@
)
else:
box_input = None

if fparam is not None:
fparam_input = to_torch_tensor(fparam.reshape(-1, self.get_dim_fparam()))

Check warning on line 367 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L366-L367

Added lines #L366 - L367 were not covered by tests
else:
fparam_input = None
if aparam is not None:
aparam_input = to_torch_tensor(

Check warning on line 371 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L369-L371

Added lines #L369 - L371 were not covered by tests
aparam.reshape(-1, natoms, self.get_dim_aparam())
)
else:
aparam_input = None

Check warning on line 375 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L375

Added line #L375 was not covered by tests
do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
)
batch_output = model(
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
coord_input,
type_input,
box=box_input,
do_atomic_virial=do_atomic_virial,
fparam=fparam_input,
aparam=aparam_input,
)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
Expand Down
9 changes: 8 additions & 1 deletion deepmd/pt/train/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def forward(
task_key: Optional[torch.Tensor] = None,
inference_only=False,
do_atomic_virial=False,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
if not self.multi_task:
task_key = "Default"
Expand All @@ -172,7 +174,12 @@ def forward(
task_key is not None
), f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}."
model_pred = self.model[task_key](
coord, atype, box=box, do_atomic_virial=do_atomic_virial
coord,
atype,
box=box,
do_atomic_virial=do_atomic_virial,
fparam=fparam,
aparam=aparam,
)
natoms = atype.shape[-1]
if not self.inference_only and not inference_only:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/infer/fparam_aparam.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ node {
dtype: DT_STRING
tensor_shape {
}
string_val: "{\"model\":{\"data_stat_nbatch\":1,\"descriptor\":{\"type\":\"se_e2_a\",\"sel\":[60],\"rcut_smth\":1.8,\"rcut\":6.0,\"neuron\":[5,10,20],\"resnet_dt\":false,\"axis_neuron\":8,\"seed\":1,\"activation_function\":\"tanh\",\"type_one_side\":false,\"precision\":\"default\",\"trainable\":true,\"exclude_types\":[],\"set_davg_zero\":false},\"fitting_net\":{\"neuron\":[5,5,5],\"resnet_dt\":true,\"numb_fparam\":1,\"numb_aparam\":1,\"seed\":1,\"type\":\"ener\",\"activation_function\":\"tanh\",\"precision\":\"default\",\"trainable\":true,\"rcond\":0.001,\"atom_ener\":[],\"use_aparam_as_mask\":false},\"data_stat_protect\":0.01,\"data_bias_nsample\":10},\"loss\":{\"start_pref_e\":0.02,\"limit_pref_e\":1,\"start_pref_f\":1000,\"limit_pref_f\":1,\"start_pref_v\":0,\"limit_pref_v\":0,\"type\":\"ener\",\"start_pref_ae\":0.0,\"limit_pref_ae\":0.0,\"start_pref_pf\":0.0,\"limit_pref_pf\":0.0,\"enable_atom_ener_coeff\":false},\"learning_rate\":{\"start_lr\":0.001,\"stop_lr\":3e-08,\"decay_steps\":5000,\"scale_by_worker\":\"linear\",\"type\":\"exp\"},\"training\":{\"training_data\":{\"systems\":[\"../data/e3000_i2000/\",\"../data/e8000_i2000/\"],\"set_prefix\":\"set\",\"batch_size\":1,\"auto_prob\":\"prob_sys_size\",\"sys_probs\":null},\"seed\":1,\"disp_file\":\"lcurve.out\",\"disp_freq\":100,\"save_freq\":1000,\"save_ckpt\":\"model.ckpt\",\"disp_training\":true,\"time_training\":true,\"profiling\":false,\"profiling_file\":\"timeline.json\",\"numb_steps\":1000,\"validation_data\":null,\"enable_profiler\":false,\"tensorboard\":false,\"tensorboard_log_dir\":\"log\",\"tensorboard_freq\":1}}"
string_val: "{\"model\":{\"data_stat_nbatch\":1,\"type_map\":[\"O\"],\"descriptor\":{\"type\":\"se_e2_a\",\"sel\":[60],\"rcut_smth\":1.8,\"rcut\":6.0,\"neuron\":[5,10,20],\"resnet_dt\":false,\"axis_neuron\":8,\"seed\":1,\"activation_function\":\"tanh\",\"type_one_side\":false,\"precision\":\"default\",\"trainable\":true,\"exclude_types\":[],\"set_davg_zero\":false},\"fitting_net\":{\"neuron\":[5,5,5],\"resnet_dt\":true,\"numb_fparam\":1,\"numb_aparam\":1,\"seed\":1,\"type\":\"ener\",\"activation_function\":\"tanh\",\"precision\":\"default\",\"trainable\":true,\"rcond\":0.001,\"atom_ener\":[],\"use_aparam_as_mask\":false},\"data_stat_protect\":0.01,\"data_bias_nsample\":10},\"loss\":{\"start_pref_e\":0.02,\"limit_pref_e\":1,\"start_pref_f\":1000,\"limit_pref_f\":1,\"start_pref_v\":0,\"limit_pref_v\":0,\"type\":\"ener\",\"start_pref_ae\":0.0,\"limit_pref_ae\":0.0,\"start_pref_pf\":0.0,\"limit_pref_pf\":0.0,\"enable_atom_ener_coeff\":false},\"learning_rate\":{\"start_lr\":0.001,\"stop_lr\":3e-08,\"decay_steps\":5000,\"scale_by_worker\":\"linear\",\"type\":\"exp\"},\"training\":{\"training_data\":{\"systems\":[\"../data/e3000_i2000/\",\"../data/e8000_i2000/\"],\"set_prefix\":\"set\",\"batch_size\":1,\"auto_prob\":\"prob_sys_size\",\"sys_probs\":null},\"seed\":1,\"disp_file\":\"lcurve.out\",\"disp_freq\":100,\"save_freq\":1000,\"save_ckpt\":\"model.ckpt\",\"disp_training\":true,\"time_training\":true,\"profiling\":false,\"profiling_file\":\"timeline.json\",\"numb_steps\":1000,\"validation_data\":null,\"enable_profiler\":false,\"tensorboard\":false,\"tensorboard_log_dir\":\"log\",\"tensorboard_freq\":1}}"
}
}
}
Expand Down
Binary file added source/tests/infer/fparam_aparam.pth
Binary file not shown.
20 changes: 20 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
DeepPot,
)

from ...tf.test_deeppot_a import TestFparamAparam as TestFparamAparamTF


class TestDeepPot(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -123,3 +125,21 @@ def setUp(self):
@unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu"))
def test_dp_test_cpu(self):
self.test_dp_test()


class TestFparamAparamPT(TestFparamAparamTF):
@classmethod
def setUpClass(cls):
cls.dp = DeepPot(
str(Path(__file__).parent.parent.parent / "infer/fparam_aparam.pth")
)

def setUp(self):
super().setUp()
# For unclear reason, the precision is only 1e-7
# not sure if it is expected...
self.places = 1e-7
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def tearDownClass(cls):
pass
47 changes: 20 additions & 27 deletions source/tests/tf/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,7 @@ def setUp(self):
2.875323131744185121e-02,
]
)
self.places = default_places

@classmethod
def tearDownClass(cls):
Expand All @@ -1030,7 +1031,7 @@ def tearDownClass(cls):

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 1)
self.assertAlmostEqual(self.dp.get_rcut(), 6.0, places=default_places)
self.assertAlmostEqual(self.dp.get_rcut(), 6.0, places=self.places)
self.assertEqual(self.dp.get_dim_fparam(), 1)
self.assertEqual(self.dp.get_dim_aparam(), 1)

Expand All @@ -1050,13 +1051,11 @@ def test_1frame(self):
self.assertEqual(ff.shape, (nframes, natoms, 3))
self.assertEqual(vv.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(
ff.ravel(), self.expected_f.ravel(), default_places
)
np.testing.assert_almost_equal(ff.ravel(), self.expected_f.ravel(), self.places)
expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_1frame_atm(self):
ee, ff, vv, ae, av = self.dp.eval(
Expand All @@ -1076,19 +1075,13 @@ def test_1frame_atm(self):
self.assertEqual(ae.shape, (nframes, natoms, 1))
self.assertEqual(av.shape, (nframes, natoms, 9))
# check values
np.testing.assert_almost_equal(
ff.ravel(), self.expected_f.ravel(), default_places
)
np.testing.assert_almost_equal(
ae.ravel(), self.expected_e.ravel(), default_places
)
np.testing.assert_almost_equal(
av.ravel(), self.expected_v.ravel(), default_places
)
np.testing.assert_almost_equal(ff.ravel(), self.expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), self.expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), self.expected_v.ravel(), self.places)
expected_se = np.sum(self.expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_2frame_atm_single_param(self):
coords2 = np.concatenate((self.coords, self.coords))
Expand All @@ -1113,13 +1106,13 @@ def test_2frame_atm_single_param(self):
expected_f = np.concatenate((self.expected_f, self.expected_f), axis=0)
expected_e = np.concatenate((self.expected_e, self.expected_e), axis=0)
expected_v = np.concatenate((self.expected_v, self.expected_v), axis=0)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), default_places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), default_places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), default_places)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), self.places)
expected_se = np.sum(expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)

def test_2frame_atm_all_param(self):
coords2 = np.concatenate((self.coords, self.coords))
Expand All @@ -1144,13 +1137,13 @@ def test_2frame_atm_all_param(self):
expected_f = np.concatenate((self.expected_f, self.expected_f), axis=0)
expected_e = np.concatenate((self.expected_e, self.expected_e), axis=0)
expected_v = np.concatenate((self.expected_v, self.expected_v), axis=0)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), default_places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), default_places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), default_places)
np.testing.assert_almost_equal(ff.ravel(), expected_f.ravel(), self.places)
np.testing.assert_almost_equal(ae.ravel(), expected_e.ravel(), self.places)
np.testing.assert_almost_equal(av.ravel(), expected_v.ravel(), self.places)
expected_se = np.sum(expected_e.reshape([nframes, -1]), axis=1)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), default_places)
np.testing.assert_almost_equal(ee.ravel(), expected_se.ravel(), self.places)
expected_sv = np.sum(expected_v.reshape([nframes, -1, 9]), axis=1)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places)
np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), self.places)


class TestDeepPotAPBCNeighborList(TestDeepPotAPBC):
Expand Down
Loading