Skip to content

Commit

Permalink
fix(tf): fix DeepEval degradation for virtual types
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Mar 14, 2024
1 parent 487f85c commit 8dab33b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool:
atom_types : np.ndarray
The atom types of all frames, in shape nframes * natoms.
"""
if np.count_nonzero(atom_types[0] == -1) > 0:

Check warning on line 245 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L245

Added line #L245 was not covered by tests
# assume mixed_types if there are virtual types, even when
# the atom types of all frames are the same
return False

Check warning on line 248 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L248

Added line #L248 was not covered by tests
return np.all(np.equal(atom_types, atom_types[0]))

@property
Expand Down
5 changes: 5 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ def make_natoms_vec(
natoms_vec[1] = natoms
for ii in range(self.ntypes):
natoms_vec[ii + 2] = np.count_nonzero(atom_types[0] == ii)
if np.count_nonzero(atom_types[0] == -1) > 0:

Check warning on line 492 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L492

Added line #L492 was not covered by tests
# contains virtual atoms
# energy fitting sums over natoms_vec[2:] instead of reading from natoms_vec[0]
# causing errors for shape mismatch
natoms_vec[2] += np.count_nonzero(atom_types[0] == -1)

Check warning on line 496 in deepmd/tf/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/infer/deep_eval.py#L496

Added line #L496 was not covered by tests
return natoms_vec

def eval_typeebd(self) -> np.ndarray:
Expand Down

0 comments on commit 8dab33b

Please sign in to comment.