From 2da5a35bc3dd6c845e9aaf932d0ecf2d98184ce7 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:51:11 +0800 Subject: [PATCH] feat: add output --- deepmd/pt/model/task/polarizability.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 90a8719a22..32637926b8 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -198,7 +198,7 @@ def compute_output_stats( sampled = merged sys_matrix, polar_bias = [], [] - for sys in len(sampled): + for sys in range(len(sampled)): if sampled[sys]["find_atomic_polarizability"] > 0.0: for itype in range(self.ntypes): # this is a tensor of shape nframes, nall @@ -275,5 +275,7 @@ def forward( "bim,bmj->bij", gr.transpose(1, 2), out ) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) + if self.shift_diag: + out = out + self.constant_matrix[atype]*torch.eye(3, device=env.DEVICE)* self.scale[atype] return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}