Skip to content

Commit

Permalink
Merge branch 'devel' into devel
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml authored Feb 13, 2024
2 parents ca75551 + 6018e3c commit 9bc9900
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
21 changes: 19 additions & 2 deletions deepmd/pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def forward(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand All @@ -63,13 +66,18 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand Down Expand Up @@ -109,7 +117,12 @@ def forward(
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord, atype, box, do_atomic_virial=do_atomic_virial
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
Expand All @@ -135,13 +148,17 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def forward_common(
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Expand Down Expand Up @@ -155,6 +159,10 @@ def forward_common_lower(
neighbor list. nf x nloc x nsel.
mapping
mapps the extended indices to local indices. nf x nall.
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.
Expand Down

0 comments on commit 9bc9900

Please sign in to comment.