From 50092c6ab33444f22fabd1e67c6c7de4fb38fb26 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 24 Sep 2024 14:11:56 +0800 Subject: [PATCH] change pt -> pd --- deepmd/pd/model/model/make_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index 3a35589458..1b634f7c8c 100644 --- a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -84,8 +84,8 @@ def __init__( self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) self.precision_dict = PRECISION_DICT self.reverse_precision_dict = RESERVED_PRECISON_DICT - self.global_pt_float_precision = GLOBAL_PD_FLOAT_PRECISION - self.global_pt_ener_float_precision = GLOBAL_PD_ENER_FLOAT_PRECISION + self.global_pd_float_precision = GLOBAL_PD_FLOAT_PRECISION + self.global_pd_ener_float_precision = GLOBAL_PD_ENER_FLOAT_PRECISION def model_output_def(self): """Get the output def for the model.""" @@ -312,11 +312,11 @@ def input_type_cast( box, fparam, aparam = _lst if ( input_prec - == self.reverse_precision_dict[self.global_pt_float_precision] + == self.reverse_precision_dict[self.global_pd_float_precision] ): return coord, box, fparam, aparam, input_prec else: - pp = self.global_pt_float_precision + pp = self.global_pd_float_precision return ( coord.to(pp), box.to(pp) if box is not None else None, @@ -333,7 +333,7 @@ def output_type_cast( """Convert the model output to the input prec.""" do_cast = ( input_prec - != self.reverse_precision_dict[self.global_pt_float_precision] + != self.reverse_precision_dict[self.global_pd_float_precision] ) pp = self.precision_dict[input_prec] odef = self.model_output_def() @@ -343,7 +343,7 @@ def output_type_cast( continue if check_operation_applied(odef[kk], OutputVariableOperation.REDU): model_ret[kk] = ( - model_ret[kk].to(self.global_pt_ener_float_precision) + model_ret[kk].to(self.global_pd_ener_float_precision) if model_ret[kk] is not None else None ) @@ -424,7 +424,7 @@ def _format_nlist( * paddle.ones( [n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype, - ).to(device=nlist.place), + ), # .to(device=nlist.place), ], axis=-1, )