From 3a0dca0ba597d7993789898e1c8d2a651834b277 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 11 Mar 2024 00:42:43 -0400 Subject: [PATCH] pt: make jit happy with torch 2.0.0 Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/task/dipole.py | 3 +++ deepmd/pt/model/task/ener.py | 6 ++++++ deepmd/pt/model/task/polarizability.py | 3 +++ 3 files changed, 12 insertions(+) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 21372888d6..3356fee16e 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -182,3 +182,6 @@ def forward( # (nframes, nloc, 3) out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + # make jit happy with torch 2.0.0 + exclude_types: List[int] diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index b593ddc3cc..6e0649eff2 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -205,6 +205,9 @@ def forward( """ return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + # make jit happy with torch 2.0.0 + exclude_types: List[int] + @Fitting.register("ener") class EnergyFittingNet(InvarFitting): @@ -252,6 +255,9 @@ def serialize(self) -> dict: "type": "ener", } + # make jit happy with torch 2.0.0 + exclude_types: List[int] + @Fitting.register("direct_force") @Fitting.register("direct_force_ener") diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index fa4f6d7f37..8595e702cf 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -220,3 +220,6 @@ def forward( out = out.view(nframes, nloc, 3, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + # make jit happy with torch 2.0.0 + exclude_types: List[int]