diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 783aed2614..98c9060704 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -114,7 +114,7 @@ def __init__( self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ).view(ntypes, 1) self.shift_diag = shift_diag - self.constant_matrix = torch.zero(self.ntypes) + self.constant_matrix = torch.zeros(self.ntypes) super().__init__( var_name=kwargs.pop("var_name", "polar"), ntypes=ntypes,