diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py index ea35cf5a82..3c0b66edcd 100644 --- a/deepmd/pt/model/model/ener.py +++ b/deepmd/pt/model/model/ener.py @@ -43,6 +43,9 @@ def forward( coord, atype, box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) model_predict = {} @@ -63,13 +66,18 @@ def forward_lower( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): model_ret = self.forward_common_lower( extended_coord, extended_atype, nlist, - mapping, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) model_predict = {} @@ -109,7 +117,12 @@ def forward( do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: model_ret = self.forward_common( - coord, atype, box, do_atomic_virial=do_atomic_virial + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) if self.fitting_net is not None: model_predict = {} @@ -135,6 +148,8 @@ def forward_lower( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): model_ret = self.forward_common_lower( @@ -142,6 +157,8 @@ def forward_lower( extended_atype, nlist, mapping, + fparam=fparam, + aparam=aparam, do_atomic_virial=do_atomic_virial, ) if self.fitting_net is not None: diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 8a863b8cdc..9191f8c58f 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -83,6 +83,10 @@ def forward_common( The type of atoms. shape: nf x nloc box The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. @@ -155,6 +159,10 @@ def forward_common_lower( neighbor list. nf x nloc x nsel. mapping mapps the extended indices to local indices. nf x nall. + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda do_atomic_virial whether calculate atomic virial.