Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 23, 2024
1 parent 15846fd commit 34cc7ce
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
12 changes: 9 additions & 3 deletions deepmd/infer/deep_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ def change_output_def(self) -> None:
)
)
self.deep_eval.output_def = self.output_def
self.deep_eval._OUTDEF_DP2BACKEND[self.get_property_name()] = f"atom_{self.get_property_name()}"
self.deep_eval._OUTDEF_DP2BACKEND[f"{self.get_property_name()}_redu"] = self.get_property_name()
self.deep_eval._OUTDEF_DP2BACKEND[self.get_property_name()] = (
f"atom_{self.get_property_name()}"
)
self.deep_eval._OUTDEF_DP2BACKEND[f"{self.get_property_name()}_redu"] = (
self.get_property_name()
)

@property
def task_dim(self) -> int:
Expand Down Expand Up @@ -125,7 +129,9 @@ def eval(
atomic_property = results[self.get_property_name()].reshape(
nframes, natoms, self.get_task_dim()
)
property = results[f"{self.get_property_name()}_redu"].reshape(nframes, self.get_task_dim())
property = results[f"{self.get_property_name()}_redu"].reshape(
nframes, self.get_task_dim()
)

if atomic:
return (
Expand Down
16 changes: 12 additions & 4 deletions deepmd/pt/model/model/property_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ def forward(
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict[f"atom_{self.get_property_name()}"] = model_ret[self.get_property_name()]
model_predict[self.get_property_name()] = model_ret[f"{self.get_property_name()}_redu"]
model_predict[f"atom_{self.get_property_name()}"] = model_ret[
self.get_property_name()
]
model_predict[self.get_property_name()] = model_ret[
f"{self.get_property_name()}_redu"
]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict
Expand Down Expand Up @@ -107,8 +111,12 @@ def forward_lower(
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
model_predict = {}
model_predict[f"atom_{self.get_property_name()}"] = model_ret[self.get_property_name()]
model_predict[self.get_property_name()] = model_ret[f"{self.get_property_name()}_redu"]
model_predict[f"atom_{self.get_property_name()}"] = model_ret[
self.get_property_name()
]
model_predict[self.get_property_name()] = model_ret[
f"{self.get_property_name()}_redu"
]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

0 comments on commit 34cc7ce

Please sign in to comment.