Skip to content

Commit

Permalink
feat: pt: support user specified rcond for fitting stat (#3279)
Browse files Browse the repository at this point in the history
the stats are actually not well tested, see
#3278

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 15, 2024
1 parent 02080db commit 43f17da
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
distinguish_types: bool = False,
rcond: Optional[float] = None,
**kwargs,
):
"""Construct a fitting net for energy.
Expand All @@ -87,6 +88,7 @@ def __init__(
self.activation_function = activation_function
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.rcond = rcond
if bias_atom_e is None:
bias_atom_e = np.zeros([self.ntypes, self.dim_out])
bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device)
Expand Down Expand Up @@ -217,8 +219,7 @@ def compute_output_stats(self, merged):
input_natoms = [item["real_natoms_vec"] for item in merged]
else:
input_natoms = [item["natoms"] for item in merged]
tmp = compute_output_bias(energy, input_natoms)
bias_atom_e = tmp[:, 0]
bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond)
return {"bias_atom_e": bias_atom_e}

def init_fitting_stat(self, bias_atom_e=None, **kwargs):
Expand All @@ -244,6 +245,7 @@ def serialize(self) -> dict:
"precision": self.precision,
"distinguish_types": self.distinguish_types,
"nets": self.filter_layers.serialize(),
"rcond": self.rcond,
"@variables": {
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"fparam_avg": to_numpy_array(self.fparam_avg),
Expand All @@ -259,7 +261,6 @@ def serialize(self) -> dict:
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported by far
"rcond": None,
"tot_ener_zero": False,
"trainable": True,
"atom_ener": None,
Expand Down

0 comments on commit 43f17da

Please sign in to comment.