Skip to content

Commit

Permalink
Merge pull request #597 from JiabinYang/support_jit_save_load
Browse files Browse the repository at this point in the history
support jist save load
  • Loading branch information
amcadmus authored May 7, 2021
2 parents f1bccb1 + 87effc5 commit 4b24e1f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
25 changes: 13 additions & 12 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,19 @@ def __init__(
fitting_param['descrpt'] = self.descrpt
self.fitting = EnerFitting(**fitting_param)

self.model = EnerModel(
self.descrpt,
self.fitting,
model_param.get('type_map'),
model_param.get('data_stat_nbatch', 10),
model_param.get('data_stat_protect', 1e-2),
model_param.get('use_srtab'),
model_param.get('smin_alpha'),
model_param.get('sw_rmin'),
model_param.get('sw_rmax')
)
self.model.set_dict(paddle.load(model_file))
# self.model = EnerModel(
# self.descrpt,
# self.fitting,
# model_param.get('type_map'),
# model_param.get('data_stat_nbatch', 10),
# model_param.get('data_stat_protect', 1e-2),
# model_param.get('use_srtab'),
# model_param.get('smin_alpha'),
# model_param.get('sw_rmin'),
# model_param.get('sw_rmax')
# )
# self.model.set_dict(paddle.load(model_file))
self.model = paddle.jit.load(model_file)
################################################################

self.load_prefix = load_prefix
Expand Down
1 change: 1 addition & 0 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _eval_inner(
else:
eval_inputs['default_mesh'] = paddle.to_tensor(np.array([], dtype = np.int32))

self.model.eval()
eval_outputs = self.model(eval_inputs['coord'], eval_inputs['type'], eval_inputs['natoms_vec'], eval_inputs['box'], eval_inputs['default_mesh'], eval_inputs, suffix = "", reuse = False)

energy = eval_outputs['energy'].numpy()
Expand Down
2 changes: 1 addition & 1 deletion deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _compute_input_stat (self, all_stat, protection = 1e-2) :
def _compute_output_stat (self, all_stat) :
self.fitting.compute_output_stats(all_stat)

#@paddle.jit.to_static
@paddle.jit.to_static
def forward (self,
coord_,
atype_,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ def train (self,
% (self.cur_batch, train_time, test_time))
train_time = 0
if self.save_freq > 0 and self.cur_batch % self.save_freq == 0 and self.run_opt.is_chief:
#paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt)
paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt)
paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt)
# paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt)
log.info("saved checkpoint to %s" % (os.getcwd() + "/" + self.save_ckpt))
if self.run_opt.is_chief:
fp.close ()
Expand Down

0 comments on commit 4b24e1f

Please sign in to comment.