Skip to content

Commit

Permalink
fix(tf): fix DeepEval degradation for virtual types (#3464)
Browse files Browse the repository at this point in the history
The energy fitting network computes `nloc` by summing over `natoms[2:]`
instead of reading `natoms[0]`.


https://github.com/deepmodeling/deepmd-kit/blob/8dab33bbe8248d9f337933f778be3e119948357e/deepmd/tf/fit/ener.py#L717-L720

This causes an issue for the virtual types after refactoring `DeepEval`.
Before, `natoms_vec` is `[nloc, nall, nloc, ...]` for mixed types. After
refactoring, we use the same `natoms_vec` for mixed_types and the normal
case, so the virtual type support is broken.

This was not detected by the test, as the test model for the virtual
types was added 12 months ago, but the energy fitting was changed 10
months ago.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Mar 15, 2024
1 parent da866a2 commit d61b152
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool:
atom_types : np.ndarray
The atom types of all frames, in shape nframes * natoms.
"""
if np.count_nonzero(atom_types[0] == -1) > 0:
# assume mixed_types if there are virtual types, even when
# the atom types of all frames are the same
return False
return np.all(np.equal(atom_types, atom_types[0]))

@property
Expand Down
5 changes: 5 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ def make_natoms_vec(
natoms_vec[1] = natoms
for ii in range(self.ntypes):
natoms_vec[ii + 2] = np.count_nonzero(atom_types[0] == ii)
if np.count_nonzero(atom_types[0] == -1) > 0:
# contains virtual atoms
# energy fitting sums over natoms_vec[2:] instead of reading from natoms_vec[0]
# causing errors for shape mismatch
natoms_vec[2] += np.count_nonzero(atom_types[0] == -1)
return natoms_vec

def eval_typeebd(self) -> np.ndarray:
Expand Down

0 comments on commit d61b152

Please sign in to comment.