Skip to content

Commit

Permalink
fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 13, 2024
1 parent 2aa2f73 commit f5f7370
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 15 deletions.
2 changes: 0 additions & 2 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@ def forward(
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)
# (nframes, nloc, 3)
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,15 +477,15 @@ def _forward_common(

outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
dtype=self.prec,
device=descriptor.device,
) # jit assertion
if self.mixed_types:
atom_property = self.filter_layers.networks[0](xx)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = (
outs + atom_property + self.bias_atom_e[atype]
outs + atom_property + self.bias_atom_e[atype].to(self.prec)
) # Shape is [nframes, natoms[0], net_dim_out]
else:
for type_i, ll in enumerate(self.filter_layers.networks):
Expand All @@ -500,7 +500,7 @@ def _forward_common(
and not self.remove_vaccum_contribution[type_i]
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec)
atom_property = torch.where(mask, atom_property, 0.0)
outs = (
outs + atom_property
Expand All @@ -509,4 +509,4 @@ def _forward_common(
mask = self.emask(atype).to(torch.bool)
# nf x nloc x nod
outs = torch.where(mask[:, :, None], outs, 0.0)
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
return {self.var_name: outs}
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ def forward(
-------
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
"""
return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

# make jit happy with torch 2.0.0
exclude_types: list[int]
4 changes: 1 addition & 3 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,7 @@ def forward(
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * (self.scale.to(atype.device))[atype]
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
out = out * (self.scale.to(atype.device).to(self.prec))[atype]

gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3)

Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["dipole"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["polarizability"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["property"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["property"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down

0 comments on commit f5f7370

Please sign in to comment.