Skip to content

Commit

Permalink
change pt -> pd
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 24, 2024
1 parent b97571e commit 50092c6
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions deepmd/pd/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def __init__(
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_pt_float_precision = GLOBAL_PD_FLOAT_PRECISION
self.global_pt_ener_float_precision = GLOBAL_PD_ENER_FLOAT_PRECISION
self.global_pd_float_precision = GLOBAL_PD_FLOAT_PRECISION
self.global_pd_ener_float_precision = GLOBAL_PD_ENER_FLOAT_PRECISION

def model_output_def(self):
"""Get the output def for the model."""
Expand Down Expand Up @@ -312,11 +312,11 @@ def input_type_cast(
box, fparam, aparam = _lst
if (
input_prec
== self.reverse_precision_dict[self.global_pt_float_precision]
== self.reverse_precision_dict[self.global_pd_float_precision]
):
return coord, box, fparam, aparam, input_prec
else:
pp = self.global_pt_float_precision
pp = self.global_pd_float_precision
return (
coord.to(pp),
box.to(pp) if box is not None else None,
Expand All @@ -333,7 +333,7 @@ def output_type_cast(
"""Convert the model output to the input prec."""
do_cast = (
input_prec
!= self.reverse_precision_dict[self.global_pt_float_precision]
!= self.reverse_precision_dict[self.global_pd_float_precision]
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
Expand All @@ -343,7 +343,7 @@ def output_type_cast(
continue
if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
model_ret[kk] = (
model_ret[kk].to(self.global_pt_ener_float_precision)
model_ret[kk].to(self.global_pd_ener_float_precision)
if model_ret[kk] is not None
else None
)
Expand Down Expand Up @@ -424,7 +424,7 @@ def _format_nlist(
* paddle.ones(
[n_nf, n_nloc, nnei - n_nnei],
dtype=nlist.dtype,
).to(device=nlist.place),
), # .to(device=nlist.place),
],
axis=-1,
)
Expand Down

0 comments on commit 50092c6

Please sign in to comment.