From 6018e3c57e70e68d8d7c106b75e55bdca5576f36 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Tue, 13 Feb 2024 17:20:57 +0800 Subject: [PATCH] fix bug of not passing params in model (#3260) Co-authored-by: Han Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/model/model/ener.py | 21 +++++++++++++++++++-- deepmd/pt/model/model/make_model.py | 8 ++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py index ea35cf5a82..3c0b66edcd 100644 --- a/deepmd/pt/model/model/ener.py +++ b/deepmd/pt/model/model/ener.py @@ -43,6 +43,9 @@ def forward( coord, atype, box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) model_predict = {} @@ -63,13 +66,18 @@ def forward_lower( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): model_ret = self.forward_common_lower( extended_coord, extended_atype, nlist, - mapping, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) model_predict = {} @@ -109,7 +117,12 @@ def forward( do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: model_ret = self.forward_common( - coord, atype, box, do_atomic_virial=do_atomic_virial + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, ) if self.fitting_net is not None: model_predict = {} @@ -135,6 +148,8 @@ def forward_lower( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): model_ret = self.forward_common_lower( @@ -142,6 +157,8 @@ def forward_lower( extended_atype, nlist, mapping, + fparam=fparam, + aparam=aparam, do_atomic_virial=do_atomic_virial, ) if self.fitting_net is not None: diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 8a863b8cdc..9191f8c58f 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -83,6 +83,10 @@ def forward_common( The type of atoms. shape: nf x nloc box The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. @@ -155,6 +159,10 @@ def forward_common_lower( neighbor list. nf x nloc x nsel. mapping mapps the extended indices to local indices. nf x nall. + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda do_atomic_virial whether calculate atomic virial.