Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Apr 6, 2024
1 parent 4c038d8 commit 0dae3c9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
13 changes: 9 additions & 4 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@ class InvarFitting(GeneralFitting):
Random seed.
exclude_types: List[int]
Atomic contributions of the excluded atom types are set zero.
atom_ener: List[float], optional
Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set.
atom_ener: List[Optional[torch.Tensor]], optional
Specifying atomic energy contribution in vacuum.
The value is a list specifying the bias. the elements can be None or np.array of output shape.
For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.]
The `set_davg_zero` key in the descrptor should be set.
"""

Expand All @@ -100,7 +103,7 @@ def __init__(
rcond: Optional[float] = None,
seed: Optional[int] = None,
exclude_types: List[int] = [],
atom_ener: Optional[List[float]] = None,
atom_ener: Optional[List[Optional[torch.Tensor]]] = None,
**kwargs,
):
self.dim_out = dim_out
Expand Down Expand Up @@ -171,7 +174,9 @@ def compute_output_stats(
keys=[self.var_name],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
preset_bias={self.var_name: self.atom_ener}
if self.atom_ener is not None
else None,
)[0][self.var_name]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

Expand Down
10 changes: 5 additions & 5 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def test_calc_and_load(self):
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
preset_bias=None,
model_forward=None,
)
# ground truth
Expand Down Expand Up @@ -399,15 +399,15 @@ def raise_error():
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
preset_bias=None,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10
)

def test_assigned(self):
atom_ener = np.array([3.0, 5.0]).reshape(2, 1)
atom_ener = {"energy": np.array([3.0, 5.0]).reshape(2, 1)}
stat_file_path = self.stat_file_path
type_map = self.type_map

Expand All @@ -417,11 +417,11 @@ def test_assigned(self):
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=atom_ener,
preset_bias=atom_ener,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret2["energy"]), atom_ener, decimal=10
to_numpy_array(ret2["energy"]), atom_ener["energy"], decimal=10
)


Expand Down

0 comments on commit 0dae3c9

Please sign in to comment.