Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 17, 2024
1 parent bca0f44 commit 3bcdc33
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def __init__(
use_aparam_as_mask: bool = False,
spin: Any = None,
distinguish_types: bool = False,
exclude_types: List[int] = [],
# not used
seed: Optional[int] = None,
):
Expand Down
1 change: 1 addition & 0 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ def serialize(self, suffix: str) -> dict:
"layer_name": self.layer_name,
"use_aparam_as_mask": self.use_aparam_as_mask,
"spin": self.spin,
"exclude_types": [],
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings
Expand Down
4 changes: 2 additions & 2 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
return (
pt_obj(
torch.from_numpy(self.inputs).to(device=PT_DEVICE),
torch.from_numpy(self.atype).to(device=PT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE),
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
Expand All @@ -169,7 +169,7 @@ def eval_dp(self, dp_obj: Any) -> Any:
) = self.param
return dp_obj(
self.inputs,
self.atype,
self.atype.reshape(1, -1),
fparam=self.fparam if numb_fparam else None,
)["energy"]

Expand Down

0 comments on commit 3bcdc33

Please sign in to comment.