Skip to content

Commit

Permalink
hack the jit script
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Feb 28, 2024
1 parent cd16e7b commit e69e1b4
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
get_deriv_name,
get_reduce_name,
)
from deepmd.pt.utils import (
env,
)


def atomic_virial_corr(
Expand Down Expand Up @@ -148,8 +151,7 @@ def fit_output_to_model_output(
the model output.
"""
## should have been GLOBAL_PT_ENER_FLOAT_PRECISION, but does not pass jit!!!
redu_prec = torch.float64
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
vdef = fit_output_def[kk]
Expand Down Expand Up @@ -188,8 +190,7 @@ def communicate_extended_output(
local and ghost (extended) atoms to local atoms.
"""
## should have been GLOBAL_PT_ENER_FLOAT_PRECISION, but does not pass jit!!!
redu_prec = torch.float64
redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand Down

0 comments on commit e69e1b4

Please sign in to comment.