From 87effc5a98af8c96610cb7cd8790275a988ed6e1 Mon Sep 17 00:00:00 2001 From: JiabinYang <360788950@qq.com> Date: Fri, 7 May 2021 10:02:59 +0000 Subject: [PATCH] support jist save load --- deepmd/infer/deep_eval.py | 25 +++++++++++++------------ deepmd/infer/deep_pot.py | 1 + deepmd/model/ener.py | 2 +- deepmd/train/trainer.py | 4 ++-- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index d6cd05dba5..9dd811e0d0 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -41,18 +41,19 @@ def __init__( fitting_param['descrpt'] = self.descrpt self.fitting = EnerFitting(**fitting_param) - self.model = EnerModel( - self.descrpt, - self.fitting, - model_param.get('type_map'), - model_param.get('data_stat_nbatch', 10), - model_param.get('data_stat_protect', 1e-2), - model_param.get('use_srtab'), - model_param.get('smin_alpha'), - model_param.get('sw_rmin'), - model_param.get('sw_rmax') - ) - self.model.set_dict(paddle.load(model_file)) + # self.model = EnerModel( + # self.descrpt, + # self.fitting, + # model_param.get('type_map'), + # model_param.get('data_stat_nbatch', 10), + # model_param.get('data_stat_protect', 1e-2), + # model_param.get('use_srtab'), + # model_param.get('smin_alpha'), + # model_param.get('sw_rmin'), + # model_param.get('sw_rmax') + # ) + # self.model.set_dict(paddle.load(model_file)) + self.model = paddle.jit.load(model_file) ################################################################ self.load_prefix = load_prefix diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 4ea6147336..62dadea2af 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -266,6 +266,7 @@ def _eval_inner( else: eval_inputs['default_mesh'] = paddle.to_tensor(np.array([], dtype = np.int32)) + self.model.eval() eval_outputs = self.model(eval_inputs['coord'], eval_inputs['type'], eval_inputs['natoms_vec'], eval_inputs['box'], eval_inputs['default_mesh'], eval_inputs, suffix = "", reuse = False) energy = eval_outputs['energy'].numpy() diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index b0335a0aeb..f0f5db3fa0 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -108,7 +108,7 @@ def _compute_input_stat (self, all_stat, protection = 1e-2) : def _compute_output_stat (self, all_stat) : self.fitting.compute_output_stats(all_stat) - #@paddle.jit.to_static + @paddle.jit.to_static def forward (self, coord_, atype_, diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index e8fb2cee20..f483bcc696 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -384,8 +384,8 @@ def train (self, % (self.cur_batch, train_time, test_time)) train_time = 0 if self.save_freq > 0 and self.cur_batch % self.save_freq == 0 and self.run_opt.is_chief: - #paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt) - paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt) + paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt) + # paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt) log.info("saved checkpoint to %s" % (os.getcwd() + "/" + self.save_ckpt)) if self.run_opt.is_chief: fp.close ()