Skip to content

Commit

Permalink
pt: make jit happy with torch 2.0.0 (#3443)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
njzjz and wanghan-iapcm authored Mar 11, 2024
1 parent 804848a commit b544885
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,6 @@ def forward(
# (nframes, nloc, 3)
out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

# make jit happy with torch 2.0.0
exclude_types: List[int]
6 changes: 6 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def forward(
"""
return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)

# make jit happy with torch 2.0.0
exclude_types: List[int]


@Fitting.register("ener")
class EnergyFittingNet(InvarFitting):
Expand Down Expand Up @@ -262,6 +265,9 @@ def serialize(self) -> dict:
"type": "ener",
}

# make jit happy with torch 2.0.0
exclude_types: List[int]


@Fitting.register("direct_force")
@Fitting.register("direct_force_ener")
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,6 @@ def forward(
out = out + bias

return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

# make jit happy with torch 2.0.0
exclude_types: List[int]

0 comments on commit b544885

Please sign in to comment.