From 05921072bfc78e5b001f5022772e699302cb20c0 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 13 Mar 2024 09:12:08 +0800 Subject: [PATCH] fix: add dtype --- deepmd/pt/model/atomic_model/linear_atomic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index a8c5b91791..98fd9b40c4 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -284,7 +284,7 @@ def _compute_weight( ) -> List[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" nmodels = len(self.models) - return [torch.ones(1) / nmodels for _ in range(nmodels)] + return [torch.ones(1, dtype=torch.float64, device=env.DEVICE) / nmodels for _ in range(nmodels)] def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model."""