Skip to content

Commit

Permalink
add support for deep dipole
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Dec 8, 2023
1 parent ab5b3eb commit e4a9e75
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 14 deletions.
4 changes: 4 additions & 0 deletions deepmd/infer/deep_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class DeepDipole(DeepTensor):
If uses the default tf graph, otherwise build a new tf graph for evaluation
input_map : dict, optional
The input map for tf.import_graph_def. Only work with default tf graph
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
Warnings
--------
Expand All @@ -41,6 +43,7 @@ def __init__(
load_prefix: str = "load",
default_tf_graph: bool = False,
input_map: Optional[dict] = None,
neighbor_list=None,
) -> None:
# use this in favor of dict update to move attribute from class to
# instance namespace
Expand All @@ -58,6 +61,7 @@ def __init__(
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)

def get_dim_fparam(self) -> int:
Expand Down
85 changes: 75 additions & 10 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class DeepTensor(DeepEval):
If uses the default tf graph, otherwise build a new tf graph for evaluation
input_map : dict, optional
The input map for tf.import_graph_def. Only work with default tf graph
neighbor_list : ase.neighborlist.NeighborList, optional
The neighbor list object. If None, then build the native neighbor list.
"""

tensors: ClassVar[Dict[str, str]] = {
Expand All @@ -63,6 +65,7 @@ def __init__(
load_prefix: str = "load",
default_tf_graph: bool = False,
input_map: Optional[dict] = None,
neighbor_list=None,
) -> None:
"""Constructor."""
DeepEval.__init__(
Expand All @@ -71,6 +74,7 @@ def __init__(
load_prefix=load_prefix,
default_tf_graph=default_tf_graph,
input_map=input_map,
neighbor_list=neighbor_list,
)
# check model type
model_type = self.tensors["t_tensor"][2:-2]
Expand Down Expand Up @@ -209,8 +213,29 @@ def eval(
)

# make natoms_vec and default_mesh
natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type)
assert natoms_vec[0] == natoms
if self.neighbor_list is None:
natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type)
assert natoms_vec[0] == natoms
mesh = make_default_mesh(pbc, mixed_type)
else:
if nframes > 1:
raise NotImplementedError(

Check warning on line 222 in deepmd/infer/deep_tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_tensor.py#L222

Added line #L222 was not covered by tests
"neighbor_list does not support multiple frames"
)
(
natoms_vec,
coords,
atom_types,
mesh,
imap,
_,
) = self.build_neighbor_list(
coords,
cells if cells is not None else None,
atom_types,
imap,
self.neighbor_list,
)

# evaluate
feed_dict_test = {}
Expand All @@ -223,7 +248,7 @@ def eval(
)
feed_dict_test[self.t_coord] = np.reshape(coords, [-1])
feed_dict_test[self.t_box] = np.reshape(cells, [-1])
feed_dict_test[self.t_mesh] = make_default_mesh(pbc, mixed_type)
feed_dict_test[self.t_mesh] = mesh

if atomic:
assert (
Expand Down Expand Up @@ -333,8 +358,30 @@ def eval_full(
)

# make natoms_vec and default_mesh
natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type)
assert natoms_vec[0] == natoms
if self.neighbor_list is None:
natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type)
assert natoms_vec[0] == natoms
mesh = make_default_mesh(pbc, mixed_type)
ghost_map = None
else:
if nframes > 1:
raise NotImplementedError(

Check warning on line 368 in deepmd/infer/deep_tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_tensor.py#L368

Added line #L368 was not covered by tests
"neighbor_list does not support multiple frames"
)
(
natoms_vec,
coords,
atom_types,
mesh,
imap,
ghost_map,
) = self.build_neighbor_list(
coords,
cells if cells is not None else None,
atom_types,
imap,
self.neighbor_list,
)

# evaluate
feed_dict_test = {}
Expand All @@ -347,7 +394,7 @@ def eval_full(
)
feed_dict_test[self.t_coord] = np.reshape(coords, [-1])
feed_dict_test[self.t_box] = np.reshape(cells, [-1])
feed_dict_test[self.t_mesh] = make_default_mesh(pbc, mixed_type)
feed_dict_test[self.t_mesh] = mesh

t_out = [self.t_global_tensor, self.t_force, self.t_virial]
if atomic:
Expand All @@ -361,21 +408,39 @@ def eval_full(
at = v_out[3] # atom tensor
av = v_out[4] # atom virial

nloc = natoms_vec[0]
nall = natoms_vec[1]

if ghost_map is not None:
# add the value of ghost atoms to real atoms
force = np.reshape(force, [nframes * nout, -1, 3])
# TODO: is there some way not to use for loop?
for ii in range(nframes * nout):
np.add.at(force[ii], ghost_map, force[ii, nloc:])
if atomic:
av = np.reshape(av, [nframes * nout, -1, 9])
for ii in range(nframes * nout):
np.add.at(av[ii], ghost_map, av[ii, nloc:])

# please note here the shape are wrong!
force = self.reverse_map(np.reshape(force, [nframes * nout, natoms, 3]), imap)
force = self.reverse_map(np.reshape(force, [nframes * nout, nall, 3]), imap)
if atomic:
at = self.reverse_map(
np.reshape(at, [nframes, len(sel_at), nout]), sel_imap
)
av = self.reverse_map(np.reshape(av, [nframes * nout, natoms, 9]), imap)
av = self.reverse_map(np.reshape(av, [nframes * nout, nall, 9]), imap)

# make sure the shapes are correct here
gt = np.reshape(gt, [nframes, nout])
force = np.reshape(force, [nframes, nout, natoms, 3])
force = np.reshape(force, [nframes, nout, nall, 3])
if nloc < nall:
force = force[:, :, :nloc, :]
virial = np.reshape(virial, [nframes, nout, 9])
if atomic:
at = np.reshape(at, [nframes, len(sel_at), self.output_dim])
av = np.reshape(av, [nframes, nout, natoms, 9])
av = np.reshape(av, [nframes, nout, nall, 9])
if nloc < nall:
av = av[:, :, :nloc, :]
return gt, force, virial, at, av
else:
return gt, force, virial
28 changes: 24 additions & 4 deletions source/tests/test_deepdipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import unittest

import ase.neighborlist
import numpy as np
from common import (
finite_difference,
Expand Down Expand Up @@ -964,10 +965,6 @@ def test_1frame_full_atm(self):
gt, ff, vv, at, av = self.dp.eval_full(
self.coords, self.box, self.atype, atomic=True
)
for dd in at, ff, av:
print("\n\n")
print(", ".join(f"{ii:.18e}" for ii in dd.reshape(-1)))
print("\n\n")
# check shape of the returns
nframes = 1
natoms = len(self.atype)
Expand Down Expand Up @@ -1035,3 +1032,26 @@ def test_1frame_full_atm_shuffle(self):
np.testing.assert_almost_equal(
vv.reshape([-1]), self.expected_gv.reshape([-1]), decimal=default_places
)


class TestDeepDipoleNewPBCNeighborList(TestDeepDipoleNewPBC):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(
str(tests_path / os.path.join("infer", "deepdipole_new.pbtxt")),
"deepdipole_new.pb",
)
cls.dp = DeepDipole(
"deepdipole_new.pb",
neighbor_list=ase.neighborlist.NewPrimitiveNeighborList(
cutoffs=6, bothways=True
),
)

@unittest.skip("multiple frames not supported")
def test_2frame_full_atm(self):
pass

@unittest.skip("multiple frames not supported")
def test_2frame_old_atm(self):
pass

0 comments on commit e4a9e75

Please sign in to comment.