[Paddle] Fixed model save issues with Ener model
The following issues are fixed:

1. Removed the @paddle.jit.to_static decorator; the model is now converted to
a static graph at save time.

2. Manually set InputSpec for the "Ener" model with the "se_a" descriptor.

3. Because the "double" data type is not supported at inference time, the
default training precision is set to float32 (low precision).
jim19930609 committed Jul 20, 2021
1 parent 55670b2 commit 75f96f4
Showing 7 changed files with 53 additions and 16 deletions.
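For context, the pattern this commit adopts can be sketched with a toy Paddle layer (illustrative only; ToyNet and its shapes are not part of this commit): the model keeps an undecorated dynamic-graph forward, and to_static plus an explicit InputSpec are applied only when exporting.

import paddle

class ToyNet(paddle.nn.Layer):
    # Stand-in for the Ener model: a plain dynamic-graph forward, no decorator.
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(3, 1)

    def forward(self, coord):
        return self.fc(coord)

net = ToyNet()
# Static-graph conversion happens only at save time, as in save_model below:
specs = [paddle.static.InputSpec(shape=[None, 3], dtype="float32", name="coord")]
paddle.jit.save(paddle.jit.to_static(net, input_spec=specs), "./toy_export/model")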
2 changes: 1 addition & 1 deletion deepmd/env.py
@@ -191,7 +191,7 @@ def _get_package_constants(
GLOBAL_TF_FLOAT_PRECISION = tf.float32
GLOBAL_PD_FLOAT_PRECISION = "float32"
GLOBAL_NP_FLOAT_PRECISION = np.float32
-GLOBAL_ENER_FLOAT_PRECISION = np.float64
+GLOBAL_ENER_FLOAT_PRECISION = np.float32
global_float_prec = "float"


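The switch from np.float64 to np.float32 for GLOBAL_ENER_FLOAT_PRECISION means total energies are now accumulated in single precision. A small illustration of the cost (generic NumPy, not repo code):

import numpy as np

GLOBAL_ENER_FLOAT_PRECISION = np.float32   # this commit; was np.float64

# A small per-atom contribution vanishes next to a large accumulated energy:
e = np.array([1.0e6, 1.0e-3], dtype=GLOBAL_ENER_FLOAT_PRECISION)
print(e.sum())                                 # 1000000.0 (the 1e-3 is rounded away)
print(np.float64(1.0e6) + np.float64(1.0e-3))  # 1000000.001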
3 changes: 1 addition & 2 deletions deepmd/model/ener.py
@@ -108,14 +108,13 @@ def _compute_input_stat (self, all_stat, protection = 1e-2) :
def _compute_output_stat (self, all_stat) :
self.fitting.compute_output_stats(all_stat)

-@paddle.jit.to_static
def forward (self,
coord_,
atype_,
natoms,
box,
mesh,
-input_dict,
+input_dict = {},
suffix = '',
reuse = None):
coord = paddle.reshape(coord_, [-1, natoms[1] * 3])
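One caveat about the new signature: input_dict = {} is a mutable default argument, which Python evaluates once at function definition time; that is safe only if forward never mutates it. A generic illustration of the pitfall (not DeePMD code):

def record(x, seen=[]):      # the default list is created once, at def time
    seen.append(x)
    return seen

print(record(1))             # [1]
print(record(2))             # [1, 2]  <- state leaked from the first call

def record_safe(x, seen=None):   # the idiomatic alternative
    seen = [] if seen is None else seen
    seen.append(x)
    return seen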
45 changes: 37 additions & 8 deletions deepmd/train/trainer.py
@@ -3,6 +3,8 @@
import os
import time
import shutil
+import copy
+import gc
import numpy as np
from deepmd.env import tf, paddle
from deepmd.env import default_tf_session_config
@@ -74,7 +76,6 @@ def _generate_descrpt_from_param_dict(descrpt_param):
else :
raise RuntimeError('unknown model type ' + descrpt_type)
return descrpt


class DPTrainer (object):
def __init__(self,
@@ -91,7 +92,7 @@ def _init_param(self, jdata):
fitting_param = j_must_have(model_param, 'fitting_net')
self.model_param = model_param
self.descrpt_param = descrpt_param

+# descriptor
try:
descrpt_type = descrpt_param['type']
@@ -105,12 +106,16 @@ def _init_param(self, jdata):
for ii in descrpt_param.get('list', []):
descrpt_list.append(_generate_descrpt_from_param_dict(ii))
self.descrpt = DescrptHybrid(descrpt_list)

+# fitting net
try:
fitting_type = fitting_param['type']
except:
fitting_type = 'ener'

+self.fitting_type = fitting_type
+self.descrpt_type = descrpt_type

fitting_param.pop('type', None)
fitting_param['descrpt'] = self.descrpt
if fitting_type == 'ener':
@@ -384,17 +389,41 @@ def train (self,
log.info("batch %7d training time %.2f s, testing time %.2f s"
% (self.cur_batch, train_time, test_time))
train_time = 0
-if self.save_freq > 0 and self.cur_batch % self.save_freq == 0 and self.run_opt.is_chief:
-    paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt)
-    # paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt)
-    log.info("saved checkpoint to %s" % (os.getcwd() + "/" + self.save_ckpt))

+if self.save_freq > 0 and self.cur_batch % self.save_freq == 0:
+    self.save_model(model_inputs, self.save_ckpt + "/model")

if self.run_opt.is_chief:
fp.close ()
if self.profiling and self.run_opt.is_chief :
fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open(self.profiling_file, 'w') as f:
f.write(chrome_trace)

+self.save_model(model_inputs, self.save_ckpt + "/model")

+def save_model(self, model_inputs_, folder_name_):
+    # Since "paddle.jit.to_static" modifies the model in-place,
+    # we make a temporary copy to avoid damaging the original model.
+    model = copy.copy(self.model)
+    save_path = os.getcwd() + "/" + folder_name_
+    if self.fitting_type == "ener" and self.descrpt_type == "se_a":
+        input_names = ['coord', 'type', 'natoms_vec', 'box', 'default_mesh']
+        input_specs = [paddle.static.InputSpec(model_inputs_[name].shape, model_inputs_[name].dtype, name=name) for name in input_names]
+    else:
+        raise NotImplementedError
+
+    try:
+        model = paddle.jit.to_static(model, input_spec=input_specs)
+        paddle.jit.save(model, save_path)
+    except Exception as e:
+        raise e
+    finally:
+        del model
+        gc.collect()
+
+    log.info("saved checkpoint to %s" % (save_path))

def get_global_step (self) :
return self.cur_batch
@@ -478,4 +507,4 @@ def test_on_the_fly (self,
print("batch %7d, lr %f, l2_l %f, l2_ener_loss %f, l2_force_loss %f, l2_virial_loss %f, l2_atom_ener_loss %f, l2_pref_force_loss %f" % (current_batch, current_lr, error_train, error_e_train, error_f_train, error_v_train, error_ae_train, error_pf_train))
print_str += " %8.1e\n" % current_lr
fp.write(print_str)
-fp.flush ()
+fp.flush ()
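With save_model in place, the exported program can be reloaded for inference without the Python-side Model class. A minimal usage sketch, assuming the checkpoint path used above ("checkpoint/model" is a placeholder for self.save_ckpt + "/model"):

import paddle

# Reload the static-graph program exported by save_model.
net = paddle.jit.load("checkpoint/model")
net.eval()
# Inputs must follow the saved InputSpec order and dtypes (float32 under this commit):
#   coord, type, natoms_vec, box, default_mesh
# energy = net(coord, atype, natoms_vec, box, default_mesh)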
2 changes: 1 addition & 1 deletion setup.py
@@ -101,7 +101,7 @@
f"-DTENSORFLOW_ROOT:STRING={tf_install_dir}",
"-DBUILD_PY_IF:BOOL=TRUE",
"-DBUILD_CPP_IF:BOOL=FALSE",
-"-DFLOAT_PREC:STRING=high",
+"-DFLOAT_PREC:STRING=low",
],
cmake_source_dir="source",
cmake_minimum_required_version="3.0",
@@ -72,6 +72,7 @@ _prepare_coord_nlist_cpu(
const int &max_cpy_trial,
const int &max_nnei_trial);

+#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> PdProdEnvMatAOpCUDAForward(
const paddle::Tensor &coord_tensor,
const paddle::Tensor &type_tensor,
@@ -85,6 +86,7 @@ std::vector<paddle::Tensor> PdProdEnvMatAOpCUDAForward(
float rcut_r_smth,
std::vector<int> sel_a,
std::vector<int> sel_r);
+#endif

template <typename data_t>
void PdProdEnvMatAOpCPUForwardKernel(
@@ -295,6 +297,7 @@ std::vector<paddle::Tensor> PdProdEnvMatAOpForward(
sel_a,
sel_r
);
+#ifdef PADDLE_WITH_CUDA
} else if (coord_tensor.place() == paddle::PlaceType::kGPU) {
return PdProdEnvMatAOpCUDAForward(
coord_tensor,
@@ -310,6 +313,7 @@ std::vector<paddle::Tensor> PdProdEnvMatAOpForward(
sel_a,
sel_r
);
+#endif
} else {
PD_THROW("Not implemented.");
}
@@ -507,4 +511,4 @@ PD_BUILD_OP(prod_env_mat_a)
"sel_r: std::vector<int>"})
.SetKernelFn(PD_KERNEL(PdProdEnvMatAOpForward))
.SetInferShapeFn(PD_INFER_SHAPE(PdProdEnvMatAOpInferShape))
-.SetInferDtypeFn(PD_INFER_DTYPE(PdProdEnvMatAOpInferDtype));
+.SetInferDtypeFn(PD_INFER_DTYPE(PdProdEnvMatAOpInferDtype));
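The #ifdef PADDLE_WITH_CUDA guards keep the CUDA declaration and dispatch branch out of CPU-only builds, which would otherwise fail to link against PdProdEnvMatAOpCUDAForward. A rough Python-side analogue of the same capability check (illustrative, not part of this commit):

import paddle

# The C++ guards decide at compile time what exists in the binary;
# Python code can make the matching check at run time:
if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")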
@@ -10,13 +10,15 @@



+#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> PdProdForceSeAOpCUDAForward(
const paddle::Tensor& net_deriv_tensor,
const paddle::Tensor& in_deriv_tensor,
const paddle::Tensor& nlist_tensor,
const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel);
+#endif

template <typename data_t>
void PdProdForceSeAOpForwardCPUKernel(
@@ -199,8 +201,10 @@ int n_a_sel,
int n_r_sel){
if(net_deriv_tensor.place() == paddle::PlaceType::kCPU){
return PdProdForceSeAOpCPUForward(net_deriv_tensor, in_deriv_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
+#ifdef PADDLE_WITH_CUDA
}else if(net_deriv_tensor.place() == paddle::PlaceType::kGPU){
return PdProdForceSeAOpCUDAForward(net_deriv_tensor, in_deriv_tensor, nlist_tensor, natoms_tensor, n_a_sel, n_r_sel);
+#endif
}else{
PD_THROW("No such kernel for PdProdForceSeAForward!");
}
@@ -281,4 +285,4 @@ PD_BUILD_OP(prod_force_se_a_grad2)
"n_r_sel: int"})
.SetKernelFn(PD_KERNEL(PdProdForceSeABackward))
.SetInferShapeFn(PD_INFER_SHAPE(PdProdForceSeAOpBackwardInferShape))
-.SetInferDtypeFn(PD_INFER_DTYPE(PdProdForceSeAOpBackwardInferDtype));
+.SetInferDtypeFn(PD_INFER_DTYPE(PdProdForceSeAOpBackwardInferDtype));
@@ -9,7 +9,7 @@
#define CHECK_INPUT_DIM(x, value) PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")



+#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> PdProdVirialSeAOpCUDAForward(
const paddle::Tensor& net_deriv_tensor,
const paddle::Tensor& in_deriv_tensor,
@@ -18,6 +18,7 @@ const paddle::Tensor& nlist_tensor,
const paddle::Tensor& natoms_tensor,
int n_a_sel,
int n_r_sel);
+#endif

template <typename data_t>
void PdProdVirialSeAOpForwardCPUKernel(
@@ -328,4 +329,4 @@ PD_BUILD_OP(prod_virial_se_a_grad2)
"n_r_sel: int"})
.SetKernelFn(PD_KERNEL(PdProdVirialSeABackward))
.SetInferShapeFn(PD_INFER_SHAPE(PdProdVirialSeAOpBackwardInferShape))
-.SetInferDtypeFn(PD_INFER_DTYPE(PdProdVirialSeAOpBackwardInferDtype));
+.SetInferDtypeFn(PD_INFER_DTYPE(PdProdVirialSeAOpBackwardInferDtype));
