diff --git a/deepmd/common.py b/deepmd/common.py index e4613aeabb..2a48c143b1 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -19,8 +19,8 @@ import numpy as np import yaml -from deepmd.env import op_module, tf -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION +from deepmd.env import op_module, tf, paddle +from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION if TYPE_CHECKING: _DICT_VAL = TypeVar("_DICT_VAL") @@ -34,10 +34,10 @@ # define constants PRECISION_DICT = { - "default": GLOBAL_TF_FLOAT_PRECISION, - "float16": tf.float16, - "float32": tf.float32, - "float64": tf.float64, + "default": GLOBAL_PD_FLOAT_PRECISION, + "float16": np.float16, + "float32": np.float32, + "float64": np.float64, } @@ -68,11 +68,11 @@ def gelu(x: tf.Tensor) -> tf.Tensor: data_requirement = {} ACTIVATION_FN_DICT = { - "relu": tf.nn.relu, - "relu6": tf.nn.relu6, - "softplus": tf.nn.softplus, - "sigmoid": tf.sigmoid, - "tanh": tf.nn.tanh, + "relu": paddle.nn.functional.relu, + "relu6": paddle.nn.functional.relu6, + "softplus": paddle.nn.functional.softplus, + "sigmoid": paddle.nn.functional.sigmoid, + "tanh": paddle.nn.functional.tanh, "gelu": gelu, } @@ -367,7 +367,7 @@ def j_loader(filename: Union[str, Path]) -> Dict[str, Any]: def get_activation_func( activation_fn: "_ACTIVATION", -) -> Callable[[tf.Tensor], tf.Tensor]: +): """Get activation function callable based on string name. Parameters diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 54325ca77c..6f92f09542 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -2,18 +2,19 @@ import numpy as np from typing import Tuple, List -from deepmd.env import tf +from deepmd.env import paddle +from paddle import to_tensor from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision from deepmd.utils.argcheck import list_to_doc -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -from deepmd.env import GLOBAL_NP_FLOAT_PRECISION -from deepmd.env import op_module -from deepmd.env import default_tf_session_config -from deepmd.utils.network import embedding_net +from deepmd.env import GLOBAL_PD_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, paddle_ops +from deepmd.utils.network import EmbeddingNet from deepmd.utils.tabulate import DeepTabulate +from collections import defaultdict +import sys -class DescrptSeA (): + +class DescrptSeA(paddle.nn.Layer): @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys())) def __init__ (self, rcut: float, @@ -63,6 +64,7 @@ def __init__ (self, precision The precision of the embedding net parameters. 
Supported options are {1} """ + super(DescrptSeA, self).__init__(name_scope="DescrptSeA") self.sel_a = sel self.rcut_r = rcut self.rcut_r_smth = rcut_smth @@ -100,32 +102,27 @@ def __init__ (self, self.dstd = None self.davg = None self.compress = False - self.place_holders = {} - avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION) - std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION) - sub_graph = tf.Graph() - with sub_graph.as_default(): - name_pfx = 'd_sea_' - for ii in ['coord', 'box']: - self.place_holders[ii] = tf.placeholder(GLOBAL_NP_FLOAT_PRECISION, [None, None], name = name_pfx+'t_'+ii) - self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type') - self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms') - self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh') - self.stat_descrpt, descrpt_deriv, rij, nlist \ - = op_module.prod_env_mat_a(self.place_holders['coord'], - self.place_holders['type'], - self.place_holders['natoms_vec'], - self.place_holders['box'], - self.place_holders['default_mesh'], - tf.constant(avg_zero), - tf.constant(std_ones), - rcut_a = self.rcut_a, - rcut_r = self.rcut_r, - rcut_r_smth = self.rcut_r_smth, - sel_a = self.sel_a, - sel_r = self.sel_r) - self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config) + self.avg_zero = paddle.zeros([self.ntypes, self.ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION) + self.std_ones = paddle.ones ([self.ntypes, self.ndescrpt], dtype=GLOBAL_PD_FLOAT_PRECISION) + + nets = [] + for type_input in range(self.ntypes) : + layer = [] + for type_i in range(self.ntypes) : + layer.append(EmbeddingNet(self.filter_neuron, self.filter_precision, self.filter_activation_fn, self.filter_resnet_dt, self.seed, self.trainable, name='filter_type_'+str(type_input)+str(type_i))) + nets.append(paddle.nn.LayerList(layer)) + + self.embedding_nets = paddle.nn.LayerList(nets) + self.t_rcut = paddle.to_tensor(np.max([self.rcut_r, self.rcut_a]), dtype = GLOBAL_PD_FLOAT_PRECISION) + self.t_ntypes = paddle.to_tensor(self.ntypes, dtype = "int32") + self.t_ndescrpt = paddle.to_tensor(self.ndescrpt, dtype = "int32") + self.t_sel = paddle.to_tensor(self.sel_a, dtype = "int32") + + t_avg = paddle.to_tensor(np.zeros([self.ntypes, self.ndescrpt]), dtype = GLOBAL_PD_FLOAT_PRECISION) + t_std = paddle.to_tensor(np.ones([self.ntypes, self.ndescrpt]), dtype = GLOBAL_PD_FLOAT_PRECISION) + self.register_buffer("t_avg", t_avg) + self.register_buffer("t_std", t_std) def get_rcut (self) -> float: """ @@ -151,7 +148,7 @@ def get_dim_rot_mat_1 (self) -> int: """ return self.filter_neuron[-1] - def get_nlist (self) -> Tuple[tf.Tensor, tf.Tensor, List[int], List[int]]: + def get_nlist (self) -> Tuple[paddle.Tensor, paddle.Tensor, List[int], List[int]]: """ Returns ------- @@ -228,6 +225,9 @@ def compute_input_stats (self, if not self.set_davg_zero: self.davg = np.array(all_davg) self.dstd = np.array(all_dstd) + + self.t_avg = paddle.to_tensor(self.davg, dtype = GLOBAL_NP_FLOAT_PRECISION) + self.t_std = paddle.to_tensor(self.dstd, dtype = GLOBAL_NP_FLOAT_PRECISION) def enable_compression(self, min_nbor_dist : float, @@ -265,16 +265,15 @@ def enable_compression(self, table_stride_1, table_stride_2) - def build (self, - coord_ : tf.Tensor, - atype_ : tf.Tensor, - natoms : tf.Tensor, - box_ : tf.Tensor, - mesh : tf.Tensor, - input_dict : dict, - reuse : bool = None, - suffix : str = '' 
- ) -> tf.Tensor: + def forward (self, + coord_, + atype_ , + natoms , + box_ , + mesh, + input_dict, + reuse = None, + suffix = ''): """ Build the computational graph for the descriptor @@ -305,42 +304,13 @@ def build (self, descriptor The output descriptor """ - davg = self.davg - dstd = self.dstd - with tf.variable_scope('descrpt_attr' + suffix, reuse = reuse) : - if davg is None: - davg = np.zeros([self.ntypes, self.ndescrpt]) - if dstd is None: - dstd = np.ones ([self.ntypes, self.ndescrpt]) - t_rcut = tf.constant(np.max([self.rcut_r, self.rcut_a]), - name = 'rcut', - dtype = GLOBAL_TF_FLOAT_PRECISION) - t_ntypes = tf.constant(self.ntypes, - name = 'ntypes', - dtype = tf.int32) - t_ndescrpt = tf.constant(self.ndescrpt, - name = 'ndescrpt', - dtype = tf.int32) - t_sel = tf.constant(self.sel_a, - name = 'sel', - dtype = tf.int32) - self.t_avg = tf.get_variable('t_avg', - davg.shape, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(davg)) - self.t_std = tf.get_variable('t_std', - dstd.shape, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(dstd)) - - coord = tf.reshape (coord_, [-1, natoms[1] * 3]) - box = tf.reshape (box_, [-1, 9]) - atype = tf.reshape (atype_, [-1, natoms[1]]) + + coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) + box = paddle.reshape(box_, [-1, 9]) + atype = paddle.reshape(atype_, [-1, natoms[1]]) self.descrpt, self.descrpt_deriv, self.rij, self.nlist \ - = op_module.prod_env_mat_a (coord, + = paddle_ops.prod_env_mat_a(coord, atype, natoms, box, @@ -352,16 +322,9 @@ def build (self, rcut_r_smth = self.rcut_r_smth, sel_a = self.sel_a, sel_r = self.sel_r) - # only used when tensorboard was set as true - tf.summary.histogram('descrpt', self.descrpt) - tf.summary.histogram('rij', self.rij) - tf.summary.histogram('nlist', self.nlist) - self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt]) - self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat') - self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv') - self.rij = tf.identity(self.rij, name = 'o_rij') - self.nlist = tf.identity(self.nlist, name = 'o_nlist') + self.descrpt_reshape = paddle.reshape(self.descrpt, [-1, self.ndescrpt]) + self.descrpt_reshape.stop_gradient = False self.dout, self.qmat = self._pass_filter(self.descrpt_reshape, atype, @@ -370,12 +333,10 @@ def build (self, suffix = suffix, reuse = reuse, trainable = self.trainable) - - # only used when tensorboard was set as true - tf.summary.histogram('embedding_net_output', self.dout) + return self.dout - def get_rot_mat(self) -> tf.Tensor: + def get_rot_mat(self) -> paddle.Tensor: """ Get rotational matrix """ @@ -383,9 +344,8 @@ def get_rot_mat(self) -> tf.Tensor: def prod_force_virial(self, - atom_ener : tf.Tensor, - natoms : tf.Tensor - ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + atom_ener, + natoms): """ Compute force and virial @@ -407,28 +367,27 @@ def prod_force_virial(self, atom_virial The atomic virial """ - [net_deriv] = tf.gradients (atom_ener, self.descrpt_reshape) - tf.summary.histogram('net_derivative', net_deriv) - net_deriv_reshape = tf.reshape (net_deriv, [-1, natoms[0] * self.ndescrpt]) + + net_deriv = paddle.grad(atom_ener, self.descrpt_reshape, create_graph=True)[0] + net_deriv_reshape = paddle.reshape (net_deriv, [-1, natoms[0] * self.ndescrpt]) + force \ - = op_module.prod_force_se_a (net_deriv_reshape, + = paddle_ops.prod_force_se_a (net_deriv_reshape, self.descrpt_deriv, self.nlist, 
natoms, n_a_sel = self.nnei_a, n_r_sel = self.nnei_r) + virial, atom_virial \ - = op_module.prod_virial_se_a (net_deriv_reshape, + = paddle_ops.prod_virial_se_a (net_deriv_reshape, self.descrpt_deriv, self.rij, self.nlist, natoms, n_a_sel = self.nnei_a, n_r_sel = self.nnei_r) - tf.summary.histogram('force', force) - tf.summary.histogram('virial', virial) - tf.summary.histogram('atom_virial', atom_virial) - + return force, virial, atom_virial @@ -441,32 +400,23 @@ def _pass_filter(self, suffix = '', trainable = True) : start_index = 0 - inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]]) + inputs = paddle.reshape(inputs, [-1, self.ndescrpt * natoms[0]]) output = [] output_qmat = [] - if not self.type_one_side: - for type_i in range(self.ntypes): - inputs_i = tf.slice (inputs, - [ 0, start_index* self.ndescrpt], - [-1, natoms[2+type_i]* self.ndescrpt] ) - inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) - layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) - qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) - output.append(layer) - output_qmat.append(qmat) - start_index += natoms[2+type_i] - else : - inputs_i = inputs - inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - type_i = -1 - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) - layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) - qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) + for type_i in range(self.ntypes): + inputs_i = paddle.slice (inputs, axes=[0, 1], + starts = [0, start_index * self.ndescrpt], + ends = [inputs.shape[0], (start_index + natoms[2+type_i]) * self.ndescrpt]) + inputs_i = paddle.reshape(inputs_i, [-1, self.ndescrpt]) + layer, qmat = self._filter(paddle.cast(inputs_i, self.filter_precision), type_i, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = paddle.reshape(layer, [inputs.shape[0], natoms[2+type_i] * self.get_dim_out()]) + qmat = paddle.reshape(qmat, [inputs.shape[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) output.append(layer) output_qmat.append(qmat) - output = tf.concat(output, axis = 1) - output_qmat = tf.concat(output_qmat, axis = 1) + start_index += natoms[2+type_i] + + output = paddle.concat(output, axis = 1) + output_qmat = paddle.concat(output_qmat, axis = 1) return output, output_qmat @@ -475,16 +425,28 @@ def _compute_dstats_sys_smth (self, data_box, data_atype, natoms_vec, - mesh) : - dd_all \ - = self.sub_sess.run(self.stat_descrpt, - feed_dict = { - self.place_holders['coord']: data_coord, - self.place_holders['type']: data_atype, - self.place_holders['natoms_vec']: natoms_vec, - self.place_holders['box']: data_box, - self.place_holders['default_mesh']: mesh, - }) + mesh) : + input_dict = {} + input_dict['coord'] = paddle.to_tensor(data_coord, dtype=GLOBAL_NP_FLOAT_PRECISION) + input_dict['box'] = paddle.to_tensor(data_box, dtype=GLOBAL_PD_FLOAT_PRECISION) + input_dict['type'] = paddle.to_tensor(data_atype, dtype="int32") + input_dict['natoms_vec'] = 
paddle.to_tensor(natoms_vec, dtype="int32") + input_dict['default_mesh'] = paddle.to_tensor(mesh, dtype="int32") + + self.stat_descrpt, descrpt_deriv, rij, nlist = paddle_ops.prod_env_mat_a(input_dict['coord'], + input_dict['type'], + input_dict['natoms_vec'], + input_dict['box'], + input_dict['default_mesh'], + self.avg_zero, + self.std_ones, + rcut_a = self.rcut_a, + rcut_r = self.rcut_r, + rcut_r_smth = self.rcut_r_smth, + sel_a = self.sel_a, + sel_r = self.sel_r) + + dd_all = self.stat_descrpt.numpy() natoms = natoms_vec dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]]) start_index = 0 @@ -521,88 +483,69 @@ def _compute_std (self,sumv2, sumv, sumn) : val = 1e-2 return val - def _filter(self, - inputs, - type_input, - natoms, - activation_fn=tf.nn.tanh, - stddev=1.0, - bavg=0.0, - name='linear', - reuse=None, - seed=None, + inputs, + type_input, + natoms, + activation_fn=paddle.nn.functional.tanh, + stddev=1.0, + bavg=0.0, + reuse=None, + seed=None, trainable = True): # natom x (nei x 4) - shape = inputs.get_shape().as_list() + shape = inputs.shape outputs_size = [1] + self.filter_neuron outputs_size_2 = self.n_axis_neuron - with tf.variable_scope(name, reuse=reuse): - start_index = 0 - xyz_scatter_total = [] - for type_i in range(self.ntypes): - # cut-out inputs - # with natom x (nei_type_i x 4) - inputs_i = tf.slice (inputs, - [ 0, start_index* 4], - [-1, self.sel_a[type_i]* 4] ) - start_index += self.sel_a[type_i] - shape_i = inputs_i.get_shape().as_list() - # with (natom x nei_type_i) x 4 - inputs_reshape = tf.reshape(inputs_i, [-1, 4]) - # with (natom x nei_type_i) x 1 - xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1]) - # with (natom x nei_type_i) x out_size - if self.compress and (type_input, type_i) not in self.exclude_types: - info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] - if self.type_one_side: - net = 'filter_-1_net_' + str(type_i) - else: - net = 'filter_' + str(type_input) + '_net_' + str(type_i) - if type_i == 0: - xyz_scatter_1 = op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) - else: - xyz_scatter_1 += op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) - else: - if (type_input, type_i) not in self.exclude_types: - xyz_scatter = embedding_net(xyz_scatter, - self.filter_neuron, - self.filter_precision, - activation_fn = activation_fn, - resnet_dt = self.filter_resnet_dt, - name_suffix = "_"+str(type_i), - stddev = stddev, - bavg = bavg, - seed = seed, - trainable = trainable) - else: - w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION) - xyz_scatter = tf.matmul(xyz_scatter, w) - # natom x nei_type_i x out_size - xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) - # xyz_scatter_total.append(xyz_scatter) - if type_i == 0 : - xyz_scatter_1 = tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) - else : - xyz_scatter_1 += tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) - # natom x nei x outputs_size - # xyz_scatter = tf.concat(xyz_scatter_total, axis=1) - # natom x nei x 4 - # inputs_reshape = tf.reshape(inputs, [-1, shape[1]//4, 4]) - # natom x 4 x outputs_size - 
# xyz_scatter_1 = tf.matmul(inputs_reshape, xyz_scatter, transpose_a = True) - xyz_scatter_1 = xyz_scatter_1 * (4.0 / shape[1]) - # natom x 4 x outputs_size_2 - xyz_scatter_2 = tf.slice(xyz_scatter_1, [0,0,0],[-1,-1,outputs_size_2]) - # # natom x 3 x outputs_size_2 - # qmat = tf.slice(xyz_scatter_2, [0,1,0], [-1, 3, -1]) - # natom x 3 x outputs_size_1 - qmat = tf.slice(xyz_scatter_1, [0,1,0], [-1, 3, -1]) - # natom x outputs_size_1 x 3 - qmat = tf.transpose(qmat, perm = [0, 2, 1]) - # natom x outputs_size x outputs_size_2 - result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a = True) - # natom x (outputs_size x outputs_size_2) - result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) + + start_index = 0 + xyz_scatter_total = [] + for type_i in range(self.ntypes): + # cut-out inputs + # with natom x (nei_type_i x 4) + inputs_i = paddle.slice (inputs, axes=[0, 1], + starts = [ 0, start_index*4], + ends = [inputs.shape[0], (start_index + self.sel_a[type_i])* 4] ) + start_index += self.sel_a[type_i] + shape_i = inputs_i.shape + # with (natom x nei_type_i) x 4 + inputs_reshape = paddle.reshape(inputs_i, [-1, 4]) + # with (natom x nei_type_i) x 1 + xyz_scatter = paddle.reshape(paddle.slice(inputs_reshape, [0, 1],[0,0],[inputs_reshape.shape[0],1]), [-1,1]) + # with (natom x nei_type_i) x out_size + + xyz_scatter = self.embedding_nets[type_input][type_i](xyz_scatter) + + # natom x nei_type_i x out_size + xyz_scatter = paddle.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) + + # xyz_scatter_total.append(xyz_scatter) + if type_i == 0 : + xyz_scatter_1 = paddle.fluid.layers.matmul(paddle.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_x = True) + else : + xyz_scatter_1 += paddle.fluid.layers.matmul(paddle.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_x = True) + + # natom x nei x outputs_size + # xyz_scatter = paddle.concat(xyz_scatter_total, axis=1) + # natom x nei x 4 + # inputs_reshape = paddle.reshape(inputs, [-1, shape[1]//4, 4]) + # natom x 4 x outputs_size + # xyz_scatter_1 = paddle.matmul(inputs_reshape, xyz_scatter, transpose_a = True) + xyz_scatter_1 = xyz_scatter_1 * (4.0 / shape[1]) + + # natom x 4 x outputs_size_2 + xyz_scatter_2 = paddle.slice(xyz_scatter_1, [0,1,2], [0,0,0],[xyz_scatter_1.shape[0],xyz_scatter_1.shape[1],outputs_size_2]) + + # # natom x 3 x outputs_size_2 + # qmat = paddle.slice(xyz_scatter_2, [0,1,0], [-1, 3, -1]) + # natom x 3 x outputs_size_1 + qmat = paddle.slice(xyz_scatter_1, [0,1,2], [0,1,0], [xyz_scatter_1.shape[0], 4, xyz_scatter_1.shape[2]]) + # natom x outputs_size_1 x 3 + qmat = paddle.transpose(qmat, perm = [0, 2, 1]) + # natom x outputs_size x outputs_size_2 + result = paddle.fluid.layers.matmul(xyz_scatter_1, xyz_scatter_2, transpose_x = True) + # natom x (outputs_size x outputs_size_2) + result = paddle.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) return result, qmat + diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 5c91547866..767a0ed147 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -242,6 +242,7 @@ def test_ener( numb_test = min(nframes, numb_test) coord = test_data["coord"][:numb_test].reshape([numb_test, -1]) + box = test_data["box"][:numb_test] if dp.has_efield: efield = test_data["efield"][:numb_test].reshape([numb_test, -1]) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 685815db17..3a9858f0c4 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -18,6 
+18,8 @@ from deepmd.utils.compat import convert_input_v0_v1 from deepmd.utils.data_system import DeepmdDataSystem +from collections import defaultdict + if TYPE_CHECKING: from deepmd.run_options import TFServerV1 @@ -260,6 +262,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): # setup data modifier modifier: Optional[DipoleChargeModifier] modi_data = jdata["model"].get("modifier", None) + if modi_data is not None: if modi_data["type"] == "dipole_charge": modifier = DipoleChargeModifier( @@ -287,12 +290,12 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): data.print_summary(run_opt, sys_probs=sys_probs, auto_prob_style=auto_prob_style) data.add_dict(data_requirement) - # build the model with stats from the first system + # # build the model with stats from the first system model.build(data, stop_batch) # train the model with the provided systems in a cyclic way start_time = time.time() - model.train(data) + model.train(data, stop_batch) end_time = time.time() log.info("finished training") log.info(f"wall time: {(end_time - start_time):.3f} s") diff --git a/deepmd/env.py b/deepmd/env.py index 8c6937b7f7..c369c99594 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -14,6 +14,8 @@ # import tensorflow v1 compatability try: + import paddle + import paddle_ops import tensorflow.compat.v1 as tf tf.disable_v2_behavior() @@ -181,11 +183,13 @@ def _get_package_constants( if GLOBAL_CONFIG["precision"] == "-DHIGH_PREC": GLOBAL_TF_FLOAT_PRECISION = tf.float64 + GLOBAL_PD_FLOAT_PRECISION = "float64" GLOBAL_NP_FLOAT_PRECISION = np.float64 GLOBAL_ENER_FLOAT_PRECISION = np.float64 global_float_prec = "double" else: GLOBAL_TF_FLOAT_PRECISION = tf.float32 + GLOBAL_PD_FLOAT_PRECISION = "float32" GLOBAL_NP_FLOAT_PRECISION = np.float32 GLOBAL_ENER_FLOAT_PRECISION = np.float64 global_float_prec = "float" @@ -207,19 +211,9 @@ def global_cvt_2_tf_float(xx: tf.Tensor) -> tf.Tensor: return tf.cast(xx, GLOBAL_TF_FLOAT_PRECISION) -def global_cvt_2_ener_float(xx: tf.Tensor) -> tf.Tensor: - """Cast tensor to globally set energy precision. 
- - Parameters - ---------- - xx : tf.Tensor - input tensor - - Returns - ------- - tf.Tensor - output tensor cast to `GLOBAL_ENER_FLOAT_PRECISION` - """ - return tf.cast(xx, GLOBAL_ENER_FLOAT_PRECISION) +def global_cvt_2_pd_float(xx: paddle.Tensor) -> paddle.Tensor: + return paddle.cast(xx, GLOBAL_PD_FLOAT_PRECISION) +def global_cvt_2_ener_float(xx: paddle.Tensor) -> paddle.Tensor: + return paddle.cast(xx, GLOBAL_ENER_FLOAT_PRECISION) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 5d60ffcc09..d91f75c21c 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -2,20 +2,20 @@ import numpy as np from typing import Tuple, List -from deepmd.env import tf +from deepmd.env import paddle from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter from deepmd.utils.argcheck import list_to_doc -from deepmd.utils.network import one_layer +from deepmd.utils.network import OneLayer from deepmd.descriptor import DescrptLocFrame from deepmd.descriptor import DescrptSeA from deepmd.env import global_cvt_2_tf_float -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION +from deepmd.env import GLOBAL_NP_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION -class EnerFitting (): - @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys())) + +class EnerFitting(paddle.nn.Layer): def __init__ (self, - descrpt : tf.Tensor, + descrpt, neuron : List[int] = [120,120,120], resnet_dt : bool = True, numb_fparam : int = 0, @@ -28,54 +28,9 @@ def __init__ (self, activation_function : str = 'tanh', precision : str = 'default' ) -> None: - """ - Constructor - - Parameters - ---------- - descrpt - The descrptor - neuron - Number of neurons in each hidden layer of the fitting net - resnet_dt - Time-step `dt` in the resnet construction: - y = x + dt * \phi (Wx + b) - numb_fparam - Number of frame parameter - numb_aparam - Number of atomic parameter - rcond - The condition number for the regression of atomic energy. - tot_ener_zero - Force the total energy to zero. Useful for the charge fitting. - trainable - If the weights of fitting net are trainable. - Suppose that we have N_l hidden layers in the fitting net, - this list is of length N_l + 1, specifying if the hidden layers and the output layer are trainable. - seed - Random seed for initializing the network parameters. - atom_ener - Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. - activation_function - The activation function in the embedding net. Supported options are {0} - precision - The precision of the embedding net parameters. 
Supported options are {1} - """ - # model param + super(EnerFitting, self).__init__(name_scope="EnerFitting") self.ntypes = descrpt.get_ntypes() self.dim_descrpt = descrpt.get_dim_out() - # args = ()\ - # .add('numb_fparam', int, default = 0)\ - # .add('numb_aparam', int, default = 0)\ - # .add('neuron', list, default = [120,120,120], alias = 'n_neuron')\ - # .add('resnet_dt', bool, default = True)\ - # .add('rcond', float, default = 1e-3) \ - # .add('tot_ener_zero', bool, default = False) \ - # .add('seed', int) \ - # .add('atom_ener', list, default = [])\ - # .add("activation_function", str, default = "tanh")\ - # .add("precision", str, default = "default")\ - # .add("trainable", [list, bool], default = True) self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam self.n_neuron = neuron @@ -94,7 +49,7 @@ def __init__ (self, self.atom_ener = [] for at, ae in enumerate(atom_ener): if ae is not None: - self.atom_ener.append(tf.constant(ae, GLOBAL_TF_FLOAT_PRECISION, name = "atom_%d_ener" % at)) + self.atom_ener.append(paddle.to_tensor(ae, dtype=GLOBAL_PD_FLOAT_PRECISION)) else: self.atom_ener.append(None) self.useBN = False @@ -110,6 +65,37 @@ def __init__ (self, self.aparam_avg = None self.aparam_std = None self.aparam_inv_std = None + + emenets = [] + for type_i in range(self.ntypes): + layers = [] + for ii in range(0,len(self.n_neuron)): + if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1]: + layers.append(OneLayer(self.n_neuron[ii-1], self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i), seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])) + else: + layers.append(OneLayer(self.dim_descrpt+self.numb_fparam+self.numb_aparam, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i), seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])) + layers.append(OneLayer(self.n_neuron[-1], 1, name='final_layer_type_'+str(type_i), seed = self.seed, activation_fn = None, precision = self.fitting_precision, trainable = self.trainable[ii])) + + emenets.append(paddle.nn.LayerList(layers)) + self.ElementNets = paddle.nn.LayerList(emenets) + + self.t_dfparam = paddle.to_tensor(self.numb_fparam, dtype = "int32") + self.t_daparam = paddle.to_tensor(self.numb_aparam, dtype = "int32") + + # stat fparam + if self.numb_fparam > 0: + self.t_fparam_avg = paddle.to_tensor(np.zeros([1, self.numb_fparam]), + dtype = GLOBAL_PD_FLOAT_PRECISION) + self.t_fparam_istd = paddle.to_tensor(np.ones([1, self.numb_fparam]), + dtype = GLOBAL_PD_FLOAT_PRECISION) + + # stat aparam + if self.numb_aparam > 0: + self.t_aparam_avg = paddle.to_tensor(np.zeros([1, self.numb_aparam]), + dtype = GLOBAL_PD_FLOAT_PRECISION) + self.t_aparam_istd = paddle.to_tensor(np.ones([1, self.numb_aparam]), + dtype = GLOBAL_PD_FLOAT_PRECISION) + def get_numb_fparam(self) -> int: """ @@ -137,6 +123,15 @@ def compute_output_stats(self, can be prepared by model.make_stat_input """ self.bias_atom_e = self._compute_output_stats(all_stat, rcond = self.rcond) + if self.bias_atom_e is not None: + assert (len(self.bias_atom_e) == self.ntypes) + for type_i in range(self.ntypes): + type_bias_ae = self.bias_atom_e[type_i] + paddle.seed(self.seed) + normal_init_ = paddle.nn.initializer.Normal(mean=type_bias_ae, std=1.0) + final_layer = self.ElementNets[type_i][-1] + normal_init_(final_layer.bias) + @classmethod def _compute_output_stats(self, all_stat, rcond = 
1e-3): @@ -175,6 +170,7 @@ def compute_input_stats(self, protection Divided-by-zero protection """ + # stat fparam if self.numb_fparam > 0: cat_data = np.concatenate(all_stat['fparam'], axis = 0) @@ -185,6 +181,12 @@ def compute_input_stats(self, if self.fparam_std[ii] < protection: self.fparam_std[ii] = protection self.fparam_inv_std = 1./self.fparam_std + + self.t_fparam_avg = paddle.to_tensor(self.fparam_avg, + dtype = GLOBAL_PD_FLOAT_PRECISION) + self.t_fparam_istd = paddle.to_tensor(self.fparam_inv_std, + dtype = GLOBAL_PD_FLOAT_PRECISION) + # stat aparam if self.numb_aparam > 0: sys_sumv = [] @@ -205,164 +207,72 @@ def compute_input_stats(self, self.aparam_std[ii] = protection self.aparam_inv_std = 1./self.aparam_std + self.t_aparam_avg = paddle.to_tensor(self.aparam_avg, + dtype = GLOBAL_PD_FLOAT_PRECISION) + self.t_aparam_istd = paddle.to_tensor(self.aparam_inv_std, + dtype = GLOBAL_PD_FLOAT_PRECISION) + def _compute_std (self, sumv2, sumv, sumn) : return np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) - - def build (self, - inputs : tf.Tensor, - natoms : tf.Tensor, - input_dict : dict = {}, - reuse : bool = None, - suffix : str = '' - ) -> tf.Tensor: - """ - Build the computational graph for fitting net - Parameters - ---------- - inputs - The input descriptor - input_dict - Additional dict for inputs. - if numb_fparam > 0, should have input_dict['fparam'] - if numb_aparam > 0, should have input_dict['aparam'] - natoms - The number of atoms. This tensor has the length of Ntypes + 2 - natoms[0]: number of local atoms - natoms[1]: total number of atoms held by this processor - natoms[i]: 2 <= i < Ntypes+2, number of type i atoms - reuse - The weights in the networks should be reused when get the variable. - suffix - Name suffix to identify this descriptor - - Return - ------ - ener - The system energy - """ - bias_atom_e = self.bias_atom_e - if self.numb_fparam > 0 and ( self.fparam_avg is None or self.fparam_inv_std is None ): + def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): + if self.numb_fparam > 0 and (self.fparam_avg is None or self.fparam_inv_std is None): raise RuntimeError('No data stat result. one should do data statisitic, before build') - if self.numb_aparam > 0 and ( self.aparam_avg is None or self.aparam_inv_std is None ): + if self.numb_aparam > 0 and (self.aparam_avg is None or self.aparam_inv_std is None): raise RuntimeError('No data stat result. 
one should do data statisitic, before build') - with tf.variable_scope('fitting_attr' + suffix, reuse = reuse) : - t_dfparam = tf.constant(self.numb_fparam, - name = 'dfparam', - dtype = tf.int32) - t_daparam = tf.constant(self.numb_aparam, - name = 'daparam', - dtype = tf.int32) - if self.numb_fparam > 0: - t_fparam_avg = tf.get_variable('t_fparam_avg', - self.numb_fparam, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(self.fparam_avg)) - t_fparam_istd = tf.get_variable('t_fparam_istd', - self.numb_fparam, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(self.fparam_inv_std)) - if self.numb_aparam > 0: - t_aparam_avg = tf.get_variable('t_aparam_avg', - self.numb_aparam, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(self.aparam_avg)) - t_aparam_istd = tf.get_variable('t_aparam_istd', - self.numb_aparam, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(self.aparam_inv_std)) - start_index = 0 - inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) - - if bias_atom_e is not None : - assert(len(bias_atom_e) == self.ntypes) + inputs = paddle.cast(paddle.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) - if self.numb_fparam > 0 : + if self.numb_fparam > 0: fparam = input_dict['fparam'] - fparam = tf.reshape(fparam, [-1, self.numb_fparam]) - fparam = (fparam - t_fparam_avg) * t_fparam_istd - if self.numb_aparam > 0 : + fparam = paddle.reshape(fparam, [-1, self.numb_fparam]) + fparam = (fparam - self.fparam_avg) * self.fparam_inv_std + if self.numb_aparam > 0: aparam = input_dict['aparam'] - aparam = tf.reshape(aparam, [-1, self.numb_aparam]) - aparam = (aparam - t_aparam_avg) * t_aparam_istd - aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) + aparam = paddle.reshape(aparam, [-1, self.numb_aparam]) + aparam = (aparam - self.aparam_avg) * self.aparam_inv_std + aparam = paddle.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) for type_i in range(self.ntypes): # cut-out inputs - inputs_i = tf.slice (inputs, - [ 0, start_index* self.dim_descrpt], - [-1, natoms[2+type_i]* self.dim_descrpt] ) - inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt]) + inputs_i = paddle.slice(inputs, [1], + [start_index * self.dim_descrpt], + [(start_index + natoms[2 + type_i]) * self.dim_descrpt]) + inputs_i = paddle.reshape(inputs_i, [-1, self.dim_descrpt]) layer = inputs_i - if self.numb_fparam > 0 : - ext_fparam = tf.tile(fparam, [1, natoms[2+type_i]]) - ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam]) - ext_fparam = tf.cast(ext_fparam,self.fitting_precision) - layer = tf.concat([layer, ext_fparam], axis = 1) - if self.numb_aparam > 0 : - ext_aparam = tf.slice(aparam, - [ 0, start_index * self.numb_aparam], - [-1, natoms[2+type_i] * self.numb_aparam]) - ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam]) - ext_aparam = tf.cast(ext_aparam,self.fitting_precision) - layer = tf.concat([layer, ext_aparam], axis = 1) - start_index += natoms[2+type_i] - - if bias_atom_e is None : - type_bias_ae = 0.0 - else : - type_bias_ae = bias_atom_e[type_i] - - for ii in range(0,len(self.n_neuron)) : + if self.numb_fparam > 0: + ext_fparam = paddle.tile(fparam, [1, natoms[2 + type_i]]) + ext_fparam = paddle.reshape(ext_fparam, [-1, self.numb_fparam]) + ext_fparam = paddle.cast(ext_fparam, self.fitting_precision) + layer = 
paddle.concat([layer, ext_fparam], axis=1) + if self.numb_aparam > 0: + ext_aparam = paddle.slice(aparam, [1], + [start_index * self.numb_aparam], + [(start_index + natoms[2 + type_i]) * self.numb_aparam]) + ext_aparam = paddle.reshape(ext_aparam, [-1, self.numb_aparam]) + ext_aparam = paddle.cast(ext_aparam, self.fitting_precision) + layer = paddle.concat([layer, ext_aparam], axis=1) + start_index += natoms[2 + type_i] + + + for ii in range(0, len(self.n_neuron)) : if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] : - layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) + layer += self.ElementNets[type_i][ii](layer) else : - layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) - final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1]) - - if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: - inputs_zero = tf.zeros_like(inputs_i, dtype=GLOBAL_TF_FLOAT_PRECISION) - layer = inputs_zero - if self.numb_fparam > 0 : - layer = tf.concat([layer, ext_fparam], axis = 1) - if self.numb_aparam > 0 : - layer = tf.concat([layer, ext_aparam], axis = 1) - for ii in range(0,len(self.n_neuron)) : - if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] : - layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) - else : - layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) - zero_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=True, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1]) - final_layer += self.atom_ener[type_i] - zero_layer - - final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]]) + layer = self.ElementNets[type_i][ii](layer) + final_layer = self.ElementNets[type_i][len(self.n_neuron)](layer) + + # if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: (not implemented) + + final_layer = paddle.reshape(final_layer, [inputs.shape[0], natoms[2 + type_i]]) # concat the results if type_i == 0: outs = final_layer else: - outs = tf.concat([outs, final_layer], axis = 1) - - if self.tot_ener_zero: - force_tot_ener = 0.0 - outs = tf.reshape(outs, [-1, natoms[0]]) - outs_mean = tf.reshape(tf.reduce_mean(outs, axis = 1), [-1, 1]) - outs_mean = outs_mean - tf.ones_like(outs_mean, dtype = GLOBAL_TF_FLOAT_PRECISION) * (force_tot_ener/global_cvt_2_tf_float(natoms[0])) - outs = outs - outs_mean - outs = tf.reshape(outs, [-1]) - - tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) - - - - + outs = paddle.concat([outs, final_layer], 
axis=1) + return paddle.cast(paddle.reshape(outs, [-1]), GLOBAL_PD_FLOAT_PRECISION) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 92a45e3cd4..d6cd05dba5 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -2,8 +2,12 @@ from typing import List, Optional, TYPE_CHECKING import numpy as np -from deepmd.common import make_default_mesh -from deepmd.env import default_tf_session_config, tf, MODEL_VERSION +import json +from deepmd.common import make_default_mesh, j_must_have +from deepmd.env import MODEL_VERSION, paddle, tf +from deepmd.fit import EnerFitting +from deepmd.descriptor import DescrptSeA +from deepmd.model import EnerModel if TYPE_CHECKING: from pathlib import Path @@ -22,9 +26,35 @@ def __init__( load_prefix: str = "load", default_tf_graph: bool = False ): - self.graph = self._load_graph( - model_file, prefix=load_prefix, default_tf_graph=default_tf_graph - ) + ##### Hard code, will change to use dy2stat, avoid to build model ####### + ##### Now use paddle.load temporarily####### + with open("out.json", 'r') as load_f: + jdata = json.load(load_f) + + model_param = j_must_have(jdata, 'model') + descrpt_param = j_must_have(model_param, 'descriptor') + descrpt_param.pop('type', None) + self.descrpt = descrpt = DescrptSeA(**descrpt_param) + + fitting_param = j_must_have(model_param, 'fitting_net') + fitting_param.pop('type', None) + fitting_param['descrpt'] = self.descrpt + self.fitting = EnerFitting(**fitting_param) + + self.model = EnerModel( + self.descrpt, + self.fitting, + model_param.get('type_map'), + model_param.get('data_stat_nbatch', 10), + model_param.get('data_stat_protect', 1e-2), + model_param.get('use_srtab'), + model_param.get('smin_alpha'), + model_param.get('sw_rmin'), + model_param.get('sw_rmax') + ) + self.model.set_dict(paddle.load(model_file)) + ################################################################ + self.load_prefix = load_prefix # graph_compatable should be called after graph and prefix are set @@ -40,11 +70,7 @@ def model_type(self) -> str: :type:str """ - if not self._model_type: - t_mt = self._get_tensor("model_attr/model_type:0") - sess = tf.Session(graph=self.graph, config=default_tf_session_config) - [mt] = sess.run([t_mt], feed_dict={}) - self._model_type = mt.decode("utf-8") + self._model_type = self.model.t_mt return self._model_type @property @@ -53,16 +79,14 @@ def model_version(self) -> str: :type:str """ + if not self._model_version: try: - t_mt = self._get_tensor("model_attr/model_version:0") - sess = tf.Session(graph=self.graph, config=default_tf_session_config) - [mt] = sess.run([t_mt], feed_dict={}) - self._model_version = mt.decode("utf-8") + self._model_version = self.model.t_ver except KeyError: # For deepmd-kit version 0.x - 1.x, set model version to 0.0 self._model_version = "0.0" - return self._model_version + return self._model_version def _graph_compatable( self @@ -83,31 +107,20 @@ def _graph_compatable( else: return True - def _get_tensor( + def _get_value( self, tensor_name: str, attr_name: Optional[str] = None - ) -> tf.Tensor: - """Get TF graph tensor and assign it to class namespace. 
- - Parameters ---------- tensor_name : str - name of tensor to get - attr_name : Optional[str], optional - if specified, class attribute with this name will be created and tensor will - be assigned to it, by default None - - Returns - ------- - tf.Tensor - loaded tensor + ): + """Look up a value from the model's registered buffers and optionally assign it to a class attribute.""" - tensor_path = os.path.join(self.load_prefix, tensor_name) - tensor = self.graph.get_tensor_by_name(tensor_path) + value = None + for name, tensor in self.model.named_buffers(): + if tensor_name in name: + value = tensor.numpy()[0] if tensor.shape == [1] else tensor.numpy() if attr_name: - setattr(self, attr_name, tensor) - return tensor + setattr(self, attr_name, value) + return value else: - return tensor + return value @staticmethod def _load_graph( diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index a8e70d5a72..4ea6147336 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -2,8 +2,9 @@ from typing import TYPE_CHECKING, List, Optional, Tuple import numpy as np +from collections import defaultdict from deepmd.common import make_default_mesh -from deepmd.env import default_tf_session_config, tf +from deepmd.env import paddle, GLOBAL_PD_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION from deepmd.infer.data_modifier import DipoleChargeModifier from deepmd.infer.deep_eval import DeepEval @@ -45,94 +46,55 @@ def __init__( self.tensors = dict( { # descrpt attrs - "t_ntypes": "descrpt_attr/ntypes:0", - "t_rcut": "descrpt_attr/rcut:0", + "ntypes": "descrpt.t_ntypes", + "rcut": "descrpt.t_rcut", # fitting attrs - "t_dfparam": "fitting_attr/dfparam:0", - "t_daparam": "fitting_attr/daparam:0", - # model attrs - "t_tmap": "model_attr/tmap:0", - # inputs - "t_coord": "t_coord:0", - "t_type": "t_type:0", - "t_natoms": "t_natoms:0", - "t_box": "t_box:0", - "t_mesh": "t_mesh:0", - # add output tensors - "t_energy": "o_energy:0", - "t_force": "o_force:0", - "t_virial": "o_virial:0", - "t_ae": "o_atom_energy:0", - "t_av": "o_atom_virial:0" + "dfparam": "fitting.t_dfparam", + "daparam": "fitting.t_daparam", }, ) DeepEval.__init__( self, model_file, load_prefix=load_prefix, - default_tf_graph=default_tf_graph ) - # load optional tensors - operations = [op.name for op in self.graph.get_operations()] - # check if the graph has these operations: - # if yes add them - if 't_efield' in operations: - self._get_tensor("t_efield:0", "t_efield") + if self._get_value("t_efield") is not None: + self._get_value("t_efield", "t_efield") self.has_efield = True else: log.debug(f"Could not get tensor 't_efield:0'") self.t_efield = None self.has_efield = False - if 'load/t_fparam' in operations: - self.tensors.update({"t_fparam": "t_fparam:0"}) + if self._get_value("t_fparam") is not None: + self.tensors.update({"t_fparam": "t_fparam"}) self.has_fparam = True else: - log.debug(f"Could not get tensor 't_fparam:0'") + log.debug(f"Could not get tensor 't_fparam'") self.t_fparam = None self.has_fparam = False - if 'load/t_aparam' in operations: - self.tensors.update({"t_aparam": "t_aparam:0"}) + if self._get_value("t_aparam") is not None: + self.tensors.update({"t_aparam": "t_aparam"}) self.has_aparam = True else: - log.debug(f"Could not get tensor 't_aparam:0'") + log.debug(f"Could not get tensor 't_aparam'") self.t_aparam = None self.has_aparam = False # now load tensors to object attributes for attr_name, tensor_name in self.tensors.items(): - self._get_tensor(tensor_name, attr_name) + self._get_value(tensor_name, attr_name) - # start a tf session associated to the graph - self.sess = 
tf.Session(graph=self.graph, config=default_tf_session_config) - self._run_default_sess() - self.tmap = self.tmap.decode('UTF-8').split() + self.t_tmap = self.model.t_tmap.split() # setup modifier try: - t_modifier_type = self._get_tensor("modifier_attr/type:0") - self.modifier_type = self.sess.run(t_modifier_type).decode("UTF-8") + self.modifier_type = self._get_value("modifier_attr.type") except (ValueError, KeyError): self.modifier_type = None - if self.modifier_type == "dipole_charge": - t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0") - t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0") - t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0") - t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0") - t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0") - [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = self.sess.run([t_mdl_name, t_mdl_charge_map, t_sys_charge_map, t_ewald_h, t_ewald_beta]) - mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()] - sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()] - self.dm = DipoleChargeModifier(mdl_name, mdl_charge_map, sys_charge_map, ewald_h = ewald_h, ewald_beta = ewald_beta) - - def _run_default_sess(self): - [self.ntypes, self.rcut, self.dfparam, self.daparam, self.tmap] = self.sess.run( - [self.t_ntypes, self.t_rcut, self.t_dfparam, self.t_daparam, self.t_tmap] - ) - def get_ntypes(self) -> int: """Get the number of atom types of this model.""" return self.ntypes @@ -143,11 +105,12 @@ def get_rcut(self) -> float: def get_type_map(self) -> List[int]: """Get the type map (element name of the atom types) of this model.""" - return self.tmap + return self.t_tmap def get_sel_type(self) -> List[int]: """Unsupported in this model.""" raise NotImplementedError("This model type does not support this attribute") + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this DP.""" return self.dfparam @@ -288,35 +251,29 @@ def _eval_inner( assert(natoms_vec[0] == natoms) # evaluate - feed_dict_test = {} - feed_dict_test[self.t_natoms] = natoms_vec - feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes, 1]).reshape([-1]) - t_out = [self.t_energy, - self.t_force, - self.t_virial] - if atomic : - t_out += [self.t_ae, - self.t_av] - - feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) - feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) - if self.has_efield: - feed_dict_test[self.t_efield]= np.reshape(efield, [-1]) - if pbc: - feed_dict_test[self.t_mesh ] = make_default_mesh(cells) - else: - feed_dict_test[self.t_mesh ] = np.array([], dtype = np.int32) + eval_inputs = {} + eval_inputs['coord'] = paddle.to_tensor(np.reshape(coords, [-1]), dtype=GLOBAL_PD_FLOAT_PRECISION) + eval_inputs['type'] = paddle.to_tensor(np.tile(atom_types, [nframes, 1]).reshape([-1]), dtype="int32") + eval_inputs['natoms_vec'] = paddle.to_tensor(natoms_vec, dtype="int32") + eval_inputs['box'] = paddle.to_tensor(np.reshape(cells , [-1]), dtype=GLOBAL_PD_FLOAT_PRECISION) + if self.has_fparam: - feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) + eval_inputs["fparam"] = paddle.to_tensor(np.reshape(fparam, [-1], dtype=GLOBAL_PD_FLOAT_PRECISION)) if self.has_aparam: - feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) - v_out = self.sess.run (t_out, feed_dict = feed_dict_test) - energy = v_out[0] - force = v_out[1] - virial = v_out[2] + eval_inputs["aparam"] = paddle.to_tensor(np.reshape(aparam, [-1], 
dtype=GLOBAL_PD_FLOAT_PRECISION)) + if pbc: + eval_inputs['default_mesh'] = paddle.to_tensor(make_default_mesh(cells), dtype="int32") + else: + eval_inputs['default_mesh'] = paddle.to_tensor(np.array([], dtype = np.int32)) + + eval_outputs = self.model(eval_inputs['coord'], eval_inputs['type'], eval_inputs['natoms_vec'], eval_inputs['box'], eval_inputs['default_mesh'], eval_inputs, suffix = "", reuse = False) + + energy = eval_outputs['energy'].numpy() + force = eval_outputs['force'].numpy() + virial = eval_outputs['virial'].numpy() if atomic: - ae = v_out[3] - av = v_out[4] + ae = eval_outputs['atom_ener'].numpy() + av = eval_outputs['atom_virial'].numpy() # reverse map of the outputs force = self.reverse_map(np.reshape(force, [nframes,-1,3]), imap) diff --git a/deepmd/load_paddle_op.py b/deepmd/load_paddle_op.py index d418d2fc90..d33d0b0914 100644 --- a/deepmd/load_paddle_op.py +++ b/deepmd/load_paddle_op.py @@ -1,17 +1,20 @@ from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup +import site + +site_package_dir = site.getsitepackages()[0] setup( name='paddle_ops', - ext_modules=CppExtension( + ext_modules=CUDAExtension( sources=['../source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cpu.cc', '../source/op/paddle_ops/srcs/pd_prod_env_mat_multi_devices_cuda.cc', '../source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cpu.cc', '../source/op/paddle_ops/srcs/pd_prod_force_se_a_multi_devices_cuda.cc', '../source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cpu.cc', '../source/op/paddle_ops/srcs/pd_prod_virial_se_a_multi_devices_cuda.cc'], - include_dirs=["../source/lib/include/","/usr/local/cuda-10.1/targets/x86_64-linux/include/"], - library_dirs=["../build/lib/", "/usr/local/cuda-10.1/lib64"], - extra_link_args=["-ldeepmd","-lcudart"] + include_dirs=["../source/lib/include/"], + library_dirs=[site_package_dir+"/deepmd/op"], + extra_link_args=["-ldeepmd"] ) ) diff --git a/deepmd/loss/ener.py b/deepmd/loss/ener.py index f25cb42219..e5bea2f028 100644 --- a/deepmd/loss/ener.py +++ b/deepmd/loss/ener.py @@ -1,8 +1,9 @@ import numpy as np -from deepmd.env import tf +from deepmd.env import tf, paddle from deepmd.common import ClassArg, add_data_requirement from deepmd.env import global_cvt_2_tf_float +from deepmd.env import global_cvt_2_pd_float from deepmd.env import global_cvt_2_ener_float class EnerStdLoss () : @@ -47,7 +48,7 @@ def __init__ (self, add_data_requirement('atom_ener', 1, atomic=True, must=False, high_prec=False) add_data_requirement('atom_pref', 1, atomic=True, must=False, high_prec=False, repeat=3) - def build (self, + def calculate_loss (self, learning_rate, natoms, model_dict, @@ -68,36 +69,36 @@ def build (self, find_atom_ener = label_dict['find_atom_ener'] find_atom_pref = label_dict['find_atom_pref'] - l2_ener_loss = tf.reduce_mean( tf.square(energy - energy_hat), name='l2_'+suffix) + l2_ener_loss = paddle.mean(paddle.square(energy - energy_hat), name='l2_'+suffix) - force_reshape = tf.reshape (force, [-1]) - force_hat_reshape = tf.reshape (force_hat, [-1]) - atom_pref_reshape = tf.reshape (atom_pref, [-1]) + force_reshape = paddle.reshape (force, [-1]) + force_hat_reshape = paddle.reshape (force_hat, [-1]) + atom_pref_reshape = paddle.reshape (atom_pref, [-1]) diff_f = force_hat_reshape - force_reshape if self.relative_f is not None: - force_hat_3 = tf.reshape(force_hat, [-1, 3]) - norm_f = tf.reshape(tf.norm(force_hat_3, axis = 1), [-1, 1]) + self.relative_f - diff_f_3 = tf.reshape(diff_f, [-1, 3]) + force_hat_3 = 
paddle.reshape(force_hat, [-1, 3]) + norm_f = paddle.reshape(paddle.norm(force_hat_3, p=2, axis = 1), [-1, 1]) + self.relative_f + diff_f_3 = paddle.reshape(diff_f, [-1, 3]) diff_f_3 = diff_f_3 / norm_f - diff_f = tf.reshape(diff_f_3, [-1]) - l2_force_loss = tf.reduce_mean(tf.square(diff_f), name = "l2_force_" + suffix) - l2_pref_force_loss = tf.reduce_mean(tf.multiply(tf.square(diff_f), atom_pref_reshape), name = "l2_pref_force_" + suffix) + diff_f = paddle.reshape(diff_f_3, [-1]) + l2_force_loss = paddle.mean(paddle.square(diff_f), name = "l2_force_" + suffix) + l2_pref_force_loss = paddle.mean(paddle.multiply(paddle.square(diff_f), atom_pref_reshape), name = "l2_pref_force_" + suffix) - virial_reshape = tf.reshape (virial, [-1]) - virial_hat_reshape = tf.reshape (virial_hat, [-1]) - l2_virial_loss = tf.reduce_mean (tf.square(virial_hat_reshape - virial_reshape), name = "l2_virial_" + suffix) + virial_reshape = paddle.reshape(virial, [-1]) + virial_hat_reshape = paddle.reshape (virial_hat, [-1]) + l2_virial_loss = paddle.mean(paddle.square(virial_hat_reshape - virial_reshape), name = "l2_virial_" + suffix) - atom_ener_reshape = tf.reshape (atom_ener, [-1]) - atom_ener_hat_reshape = tf.reshape (atom_ener_hat, [-1]) - l2_atom_ener_loss = tf.reduce_mean (tf.square(atom_ener_hat_reshape - atom_ener_reshape), name = "l2_atom_ener_" + suffix) + atom_ener_reshape = paddle.reshape (atom_ener, [-1]) + atom_ener_hat_reshape = paddle.reshape (atom_ener_hat, [-1]) + l2_atom_ener_loss = paddle.mean (paddle.square(atom_ener_hat_reshape - atom_ener_reshape), name = "l2_atom_ener_" + suffix) - atom_norm = 1./ global_cvt_2_tf_float(natoms[0]) + atom_norm = 1./ global_cvt_2_pd_float(natoms[0]) atom_norm_ener = 1./ global_cvt_2_ener_float(natoms[0]) - pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate) ) - pref_f = global_cvt_2_tf_float(find_force * (self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * learning_rate / self.starter_learning_rate) ) - pref_v = global_cvt_2_tf_float(find_virial * (self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * learning_rate / self.starter_learning_rate) ) - pref_ae= global_cvt_2_tf_float(find_atom_ener * (self.limit_pref_ae+ (self.start_pref_ae-self.limit_pref_ae) * learning_rate / self.starter_learning_rate) ) - pref_pf= global_cvt_2_tf_float(find_atom_pref * (self.limit_pref_pf+ (self.start_pref_pf-self.limit_pref_pf) * learning_rate / self.starter_learning_rate) ) + pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate)) + pref_f = global_cvt_2_pd_float(find_force * (self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * learning_rate / self.starter_learning_rate) ) + pref_v = global_cvt_2_pd_float(find_virial * (self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * learning_rate / self.starter_learning_rate) ) + pref_ae= global_cvt_2_pd_float(find_atom_ener * (self.limit_pref_ae+ (self.start_pref_ae-self.limit_pref_ae) * learning_rate / self.starter_learning_rate) ) + pref_pf= global_cvt_2_pd_float(find_atom_pref * (self.limit_pref_pf+ (self.start_pref_pf-self.limit_pref_pf) * learning_rate / self.starter_learning_rate) ) l2_loss = 0 more_loss = {} @@ -105,24 +106,18 @@ def build (self, l2_loss += atom_norm_ener * (pref_e * l2_ener_loss) more_loss['l2_ener_loss'] = l2_ener_loss if self.has_f : - l2_loss += global_cvt_2_ener_float(pref_f 
* l2_force_loss) + l2_loss += global_cvt_2_pd_float(pref_f * l2_force_loss) more_loss['l2_force_loss'] = l2_force_loss if self.has_v : - l2_loss += global_cvt_2_ener_float(atom_norm * (pref_v * l2_virial_loss)) + l2_loss += global_cvt_2_pd_float(atom_norm * (pref_v * l2_virial_loss)) more_loss['l2_virial_loss'] = l2_virial_loss if self.has_ae : - l2_loss += global_cvt_2_ener_float(pref_ae * l2_atom_ener_loss) + l2_loss += global_cvt_2_pd_float(pref_ae * l2_atom_ener_loss) more_loss['l2_atom_ener_loss'] = l2_atom_ener_loss if self.has_pf : - l2_loss += global_cvt_2_ener_float(pref_pf * l2_pref_force_loss) + l2_loss += global_cvt_2_pd_float(pref_pf * l2_pref_force_loss) more_loss['l2_pref_force_loss'] = l2_pref_force_loss - # only used when tensorboard was set as true - self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss)) - self.l2_loss_ener_summary = tf.summary.scalar('l2_ener_loss', global_cvt_2_tf_float(tf.sqrt(l2_ener_loss)) / global_cvt_2_tf_float(natoms[0])) - self.l2_loss_force_summary = tf.summary.scalar('l2_force_loss', tf.sqrt(l2_force_loss)) - self.l2_loss_virial_summary = tf.summary.scalar('l2_virial_loss', tf.sqrt(l2_virial_loss) / global_cvt_2_tf_float(natoms[0])) - self.l2_l = l2_loss self.l2_more = more_loss return l2_loss, more_loss @@ -141,59 +136,8 @@ def print_header(self): print_str += prop_fmt % ('rmse_v_tst', 'rmse_v_trn') if self.has_pf : print_str += prop_fmt % ('rmse_pf_tst', 'rmse_pf_trn') - return print_str - - def print_on_training(self, - tb_writer, - cur_batch, - sess, - natoms, - feed_dict_test, - feed_dict_batch): - run_data = [ - self.l2_l, - self.l2_more['l2_ener_loss'], - self.l2_more['l2_force_loss'], - self.l2_more['l2_virial_loss'], - self.l2_more['l2_atom_ener_loss'], - self.l2_more['l2_pref_force_loss'] - ] - - # first train data - train_out = sess.run(run_data, feed_dict=feed_dict_batch) - error_train, error_e_train, error_f_train, error_v_train, error_ae_train, error_pf_train = train_out - - # than test data, if tensorboard log writter is present, commpute summary - # and write tensorboard logs - if tb_writer: - summary_merged_op = tf.summary.merge([self.l2_loss_summary, self.l2_loss_ener_summary, self.l2_loss_force_summary, self.l2_loss_virial_summary]) - run_data.insert(0, summary_merged_op) - - test_out = sess.run(run_data, feed_dict=feed_dict_test) - - if tb_writer: - summary = test_out.pop(0) - tb_writer.add_summary(summary, cur_batch) - - error_test, error_e_test, error_f_test, error_v_test, error_ae_test, error_pf_test = test_out - - - print_str = "" - prop_fmt = " %11.2e %11.2e" - print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train)) - if self.has_e : - print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0]) - if self.has_ae : - print_str += prop_fmt % (np.sqrt(error_ae_test), np.sqrt(error_ae_train)) - if self.has_f : - print_str += prop_fmt % (np.sqrt(error_f_test), np.sqrt(error_f_train)) - if self.has_v : - print_str += prop_fmt % (np.sqrt(error_v_test) / natoms[0], np.sqrt(error_v_train) / natoms[0]) - if self.has_pf: - print_str += prop_fmt % (np.sqrt(error_pf_test), np.sqrt(error_pf_train)) - - return print_str + return print_str class EnerDipoleLoss () : @@ -318,7 +262,4 @@ def print_on_training(self, print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train)) print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0]) print_str += prop_fmt % (np.sqrt(error_ed_test), np.sqrt(error_ed_train)) - return print_str - - - + 
return print_str \ No newline at end of file diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index 3cc3d4bd8b..b0335a0aeb 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -1,14 +1,17 @@ import numpy as np from typing import Tuple, List -from deepmd.env import tf +from deepmd.env import paddle from deepmd.utils.pair_tab import PairTab from deepmd.common import ClassArg -from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION -from deepmd.env import op_module +from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_ENER_FLOAT_PRECISION +from deepmd.env import op_module, paddle_ops from .model_stat import make_stat_input, merge_sys_stat -class EnerModel() : +import sys + + +class EnerModel(paddle.nn.Layer) : model_type = 'ener' def __init__ ( @@ -49,6 +52,7 @@ def __init__ ( The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. """ # descriptor + super(EnerModel, self).__init__(name_scope="EnerModel") self.descrpt = descrpt self.rcut = self.descrpt.get_rcut() self.ntypes = self.descrpt.get_ntypes() @@ -70,6 +74,10 @@ def __init__ ( self.sw_rmax = sw_rmax else : self.srtab = None + + self.t_tmap = ' '.join(self.type_map) + self.t_mt = self.model_type + self.t_ver = MODEL_VERSION def get_rcut (self) : @@ -100,8 +108,8 @@ def _compute_input_stat (self, all_stat, protection = 1e-2) : def _compute_output_stat (self, all_stat) : self.fitting.compute_output_stats(all_stat) - - def build (self, + #@paddle.jit.to_static + def forward (self, coord_, atype_, natoms, @@ -110,121 +118,37 @@ def build (self, input_dict, suffix = '', reuse = None): + coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) + atype = paddle.reshape(atype_, [-1, natoms[1]]) - with tf.variable_scope('model_attr' + suffix, reuse = reuse) : - t_tmap = tf.constant(' '.join(self.type_map), - name = 'tmap', - dtype = tf.string) - t_mt = tf.constant(self.model_type, - name = 'model_type', - dtype = tf.string) - t_ver = tf.constant(MODEL_VERSION, - name = 'model_version', - dtype = tf.string) - - if self.srtab is not None : - tab_info, tab_data = self.srtab.get() - self.tab_info = tf.get_variable('t_tab_info', - tab_info.shape, - dtype = tf.float64, - trainable = False, - initializer = tf.constant_initializer(tab_info, dtype = tf.float64)) - self.tab_data = tf.get_variable('t_tab_data', - tab_data.shape, - dtype = tf.float64, - trainable = False, - initializer = tf.constant_initializer(tab_data, dtype = tf.float64)) - - coord = tf.reshape (coord_, [-1, natoms[1] * 3]) - atype = tf.reshape (atype_, [-1, natoms[1]]) - - dout \ - = self.descrpt.build(coord_, - atype_, - natoms, - box, - mesh, - input_dict, - suffix = suffix, - reuse = reuse) - dout = tf.identity(dout, name='o_descriptor') - - if self.srtab is not None : - nlist, rij, sel_a, sel_r = self.descrpt.get_nlist() - nnei_a = np.cumsum(sel_a)[-1] - nnei_r = np.cumsum(sel_r)[-1] - - atom_ener = self.fitting.build (dout, - natoms, - input_dict, - reuse = reuse, - suffix = suffix) - - if self.srtab is not None : - sw_lambda, sw_deriv \ - = op_module.soft_min_switch(atype, - rij, - nlist, - natoms, - sel_a = sel_a, - sel_r = sel_r, - alpha = self.smin_alpha, - rmin = self.sw_rmin, - rmax = self.sw_rmax) - inv_sw_lambda = 1.0 - sw_lambda - # NOTICE: - # atom energy is not scaled, - # force and virial are scaled - tab_atom_ener, tab_force, tab_atom_virial \ - = op_module.pair_tab(self.tab_info, - self.tab_data, - atype, - rij, - nlist, - natoms, - sw_lambda, - 
sel_a = sel_a, - sel_r = sel_r) - energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]]) - tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_atom_ener, [-1]) - atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener - energy_raw = tab_atom_ener + atom_ener - else : - energy_raw = atom_ener - - energy_raw = tf.reshape(energy_raw, [-1, natoms[0]], name = 'o_atom_energy'+suffix) - energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw), axis=1, name='o_energy'+suffix) - - force, virial, atom_virial \ - = self.descrpt.prod_force_virial (atom_ener, natoms) - - if self.srtab is not None : - sw_force \ - = op_module.soft_min_force(energy_diff, - sw_deriv, - nlist, - natoms, - n_a_sel = nnei_a, - n_r_sel = nnei_r) - force = force + sw_force + tab_force - - force = tf.reshape (force, [-1, 3 * natoms[1]], name = "o_force"+suffix) - - if self.srtab is not None : - sw_virial, sw_atom_virial \ - = op_module.soft_min_virial (energy_diff, - sw_deriv, - rij, - nlist, - natoms, - n_a_sel = nnei_a, - n_r_sel = nnei_r) - atom_virial = atom_virial + sw_atom_virial + tab_atom_virial - virial = virial + sw_virial \ - + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis = 1) - - virial = tf.reshape (virial, [-1, 9], name = "o_virial"+suffix) - atom_virial = tf.reshape (atom_virial, [-1, 9 * natoms[1]], name = "o_atom_virial"+suffix) + dout = self.descrpt(coord_, + atype_, + natoms, + box, + mesh, + input_dict, + suffix = suffix, + reuse = reuse) + + self.dout = dout + + atom_ener = self.fitting (dout, + natoms, + input_dict, + reuse = reuse, + suffix = suffix) + + self.atom_ener = atom_ener + energy_raw = atom_ener + + energy_raw = paddle.reshape(energy_raw, [-1, natoms[0]], name = 'o_atom_energy'+suffix) + energy = paddle.sum(paddle.cast(energy_raw, GLOBAL_ENER_FLOAT_PRECISION), axis=1, name='o_energy'+suffix) + + force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms) + + force = paddle.reshape(force, [-1, 3 * natoms[1]], name = "o_force"+suffix) + virial = paddle.reshape(virial, [-1, 9], name = "o_virial"+suffix) + atom_virial = paddle.reshape(atom_virial, [-1, 9 * natoms[1]], name = "o_atom_virial"+suffix) model_dict = {} model_dict['energy'] = energy @@ -234,6 +158,5 @@ def build (self, model_dict['atom_virial'] = atom_virial model_dict['coord'] = coord model_dict['atype'] = atype - - return model_dict + return model_dict diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 1ea177fd95..e8fb2cee20 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -4,9 +4,9 @@ import time import shutil import numpy as np -from deepmd.env import tf +from deepmd.env import tf, paddle from deepmd.env import default_tf_session_config -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION +from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION from deepmd.fit import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA from deepmd.descriptor import DescrptLocFrame @@ -25,6 +25,10 @@ from tensorflow.python.client import timeline from deepmd.env import op_module +from collections import defaultdict +import sys + + # load grad of force module import deepmd.op @@ -241,7 +245,7 @@ def _init_param(self, jdata): tr_args = ClassArg()\ .add('numb_test', [int, list, str], default = 1)\ .add('disp_file', str, default = 'lcurve.out')\ - .add('disp_freq', int, default = 100)\ + .add('disp_freq', int, default = 1)\ 
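# Illustrative sketch, not part of the patch: the ener.py hunk above turns EnerModel.build() into forward() on a paddle.nn.Layer. A Layer instance is callable and __call__ dispatches to forward(), which is why the trainer below can invoke the model simply as self.model(...). TinyModel is a made-up name used only for this sketch.
import paddle

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.fc = paddle.nn.Linear(3, 1)   # a minimal trainable sub-layer

    def forward(self, x):
        return self.fc(x)

m = TinyModel()
y = m(paddle.ones([4, 3]))   # equivalent to m.forward(paddle.ones([4, 3]))
print(y.shape)               # [4, 1]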
.add('save_freq', int, default = 1000)\ .add('save_ckpt', str, default = 'model.ckpt')\ .add('display_in_training', bool, default = True)\ @@ -273,6 +277,7 @@ def _init_param(self, jdata): else : self.numb_fparam = 0 + def build (self, data, stop_batch = 0) : @@ -293,175 +298,28 @@ def build (self, self.type_map = data.get_type_map() self.model.data_stat(data) + self.lr_scheduler = self.lr.build(self.stop_batch) - if 'compress' in self.model_param and self.model_param['compress']['compress']: - assert 'rcut' in self.descrpt_param, "Error: descriptor must have attr rcut!" - self.neighbor_stat \ - = NeighborStat(self.ntypes, self.descrpt_param['rcut']) - self.min_nbor_dist, self.max_nbor_size \ - = self.neighbor_stat.get_stat(data) - self.descrpt.enable_compression(self.min_nbor_dist, self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3]) - - worker_device = "/job:%s/task:%d/%s" % (self.run_opt.my_job_name, - self.run_opt.my_task_index, - self.run_opt.my_device) - - with tf.device(tf.train.replica_device_setter(worker_device = worker_device, - cluster = self.run_opt.cluster_spec)): - self._build_lr() - self._build_network(data) - self._build_training() - - - def _build_lr(self): - self._extra_train_ops = [] - self.global_step = tf.train.get_or_create_global_step() - self.learning_rate = self.lr.build(self.global_step, self.stop_batch) - log.info("built lr") - - def _build_network(self, data): - self.place_holders = {} - data_dict = data.get_data_dict() - for kk in data_dict.keys(): - if kk == 'type': - continue - prec = GLOBAL_TF_FLOAT_PRECISION - if data_dict[kk]['high_prec'] : - prec = GLOBAL_ENER_FLOAT_PRECISION - self.place_holders[kk] = tf.placeholder(prec, [None], name = 't_' + kk) - self.place_holders['find_'+kk] = tf.placeholder(tf.float32, name = 't_find_' + kk) - - self.place_holders['type'] = tf.placeholder(tf.int32, [None], name='t_type') - self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name='t_natoms') - self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name='t_mesh') - self.place_holders['is_training'] = tf.placeholder(tf.bool) - self.model_pred\ - = self.model.build (self.place_holders['coord'], - self.place_holders['type'], - self.place_holders['natoms_vec'], - self.place_holders['box'], - self.place_holders['default_mesh'], - self.place_holders, - suffix = "", - reuse = False) - - self.l2_l, self.l2_more\ - = self.loss.build (self.learning_rate, - self.place_holders['natoms_vec'], - self.model_pred, - self.place_holders, - suffix = "test") - - log.info("built network") - - def _build_training(self): - trainable_variables = tf.trainable_variables() - optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate) - if self.run_opt.is_distrib : - optimizer = tf.train.SyncReplicasOptimizer( - optimizer, - replicas_to_aggregate = self.run_opt.cluster_spec.num_tasks("worker"), - total_num_replicas = self.run_opt.cluster_spec.num_tasks("worker"), - name = "sync_replicas") - self.sync_replicas_hook = optimizer.make_session_run_hook(self.run_opt.is_chief) - grads = tf.gradients(self.l2_l, trainable_variables) - apply_op = optimizer.apply_gradients (zip (grads, trainable_variables), - global_step=self.global_step, - name='train_step') - train_ops = [apply_op] + self._extra_train_ops - self.train_op = tf.group(*train_ops) - log.info("built 
training") - - def _init_sess_serial(self) : - self.sess = tf.Session(config=default_tf_session_config) - self.saver = tf.train.Saver() - saver = self.saver - if self.run_opt.init_mode == 'init_from_scratch' : - log.info("initialize model from scratch") - init_op = tf.global_variables_initializer() - self.sess.run(init_op) - fp = open(self.disp_file, "w") - fp.close () - elif self.run_opt.init_mode == 'init_from_model' : - log.info("initialize from model %s" % self.run_opt.init_model) - init_op = tf.global_variables_initializer() - self.sess.run(init_op) - saver.restore (self.sess, self.run_opt.init_model) - self.sess.run(self.global_step.assign(0)) - fp = open(self.disp_file, "w") - fp.close () - elif self.run_opt.init_mode == 'restart' : - log.info("restart from model %s" % self.run_opt.restart) - init_op = tf.global_variables_initializer() - self.sess.run(init_op) - saver.restore (self.sess, self.run_opt.restart) - else : - raise RuntimeError ("unkown init mode") - - def _init_sess_distrib(self): - ckpt_dir = os.path.join(os.getcwd(), self.save_ckpt) - assert(_is_subdir(ckpt_dir, os.getcwd())), "the checkpoint dir must be a subdir of the current dir" - if self.run_opt.init_mode == 'init_from_scratch' : - log.info("initialize model from scratch") - if self.run_opt.is_chief : - if os.path.exists(ckpt_dir): - shutil.rmtree(ckpt_dir) - if not os.path.exists(ckpt_dir) : - os.makedirs(ckpt_dir) - fp = open(self.disp_file, "w") - fp.close () - elif self.run_opt.init_mode == 'init_from_model' : - raise RuntimeError("distributed training does not support %s" % self.run_opt.init_mode) - elif self.run_opt.init_mode == 'restart' : - log.info("restart from model %s" % ckpt_dir) - if self.run_opt.is_chief : - assert(os.path.isdir(ckpt_dir)), "the checkpoint dir %s should exists" % ckpt_dir - else : - raise RuntimeError ("unkown init mode") - - saver = tf.train.Saver(max_to_keep = 1) - self.saver = None - # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) - # config = tf.ConfigProto(allow_soft_placement=True, - # gpu_options = gpu_options, - # intra_op_parallelism_threads=self.run_opt.num_intra_threads, - # inter_op_parallelism_threads=self.run_opt.num_inter_threads) - config = tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads, - inter_op_parallelism_threads=self.run_opt.num_inter_threads) - # The stop_hook handles stopping after running given steps - # stop_hook = tf.train.StopAtStepHook(last_step = stop_batch) - # hooks = [self.sync_replicas_hook, stop_hook] - hooks = [self.sync_replicas_hook] - scaffold = tf.train.Scaffold(saver=saver) - # Use monitor session for distributed computation - self.sess = tf.train.MonitoredTrainingSession(master = self.run_opt.server.target, - is_chief = self.run_opt.is_chief, - config = config, - hooks = hooks, - scaffold = scaffold, - checkpoint_dir = ckpt_dir) - # , - # save_checkpoint_steps = self.save_freq) def train (self, - data) : - stop_batch = self.stop_batch - if self.run_opt.is_distrib : - self._init_sess_distrib() - else : - self._init_sess_serial() + data, + stop_batch) : + paddle.set_device("gpu") + self.stop_batch = stop_batch self.print_head() fp = None if self.run_opt.is_chief : fp = open(self.disp_file, "a") - cur_batch = self.sess.run(self.global_step) is_first_step = True - self.cur_batch = cur_batch + self.cur_batch = 0 + + adam = paddle.optimizer.Adam(learning_rate = self.lr_scheduler, parameters=self.model.parameters()) + log.info("start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr 
will be %.2e" % - (self.sess.run(self.learning_rate), - self.lr.value(cur_batch), + (self.lr_scheduler.get_lr(), + self.lr.value(self.cur_batch), self.lr.decay_steps_, self.lr.decay_rate_, self.lr.value(stop_batch)) @@ -470,67 +328,65 @@ def train (self, prf_options = None prf_run_metadata = None if self.profiling : - prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - prf_run_metadata = tf.RunMetadata() - - # set tensorboard execution environment - if self.tensorboard : - summary_merged_op = tf.summary.merge_all() - shutil.rmtree(self.tensorboard_log_dir) - tb_train_writer = tf.summary.FileWriter(self.tensorboard_log_dir + '/train', self.sess.graph) - tb_test_writer = tf.summary.FileWriter(self.tensorboard_log_dir + '/test') - else: - tb_train_writer = None - tb_test_writer = None + pass + + tb_train_writer = None + tb_test_writer = None train_time = 0 - while cur_batch < stop_batch : + + data_dict = data.get_data_dict() + while self.cur_batch < stop_batch : batch_data = data.get_batch (sys_probs = self.sys_probs, auto_prob_style = self.auto_prob_style ) - feed_dict_batch = {} + model_inputs = {} for kk in batch_data.keys(): if kk == 'find_type' or kk == 'type' : continue + prec = GLOBAL_PD_FLOAT_PRECISION if 'find_' in kk : - feed_dict_batch[self.place_holders[kk]] = batch_data[kk] + model_inputs[kk] = paddle.to_tensor(batch_data[kk], dtype="float32") else: - feed_dict_batch[self.place_holders[kk]] = np.reshape(batch_data[kk], [-1]) + model_inputs[kk] = paddle.to_tensor(np.reshape(batch_data[kk], [-1]), dtype=prec) for ii in ['type'] : - feed_dict_batch[self.place_holders[ii]] = np.reshape(batch_data[ii], [-1]) + model_inputs[ii] = paddle.to_tensor(np.reshape(batch_data[ii], [-1]), dtype="int32") for ii in ['natoms_vec', 'default_mesh'] : - feed_dict_batch[self.place_holders[ii]] = batch_data[ii] - feed_dict_batch[self.place_holders['is_training']] = True - + model_inputs[ii] = paddle.to_tensor(batch_data[ii], dtype="int32") + model_inputs['is_training'] = paddle.to_tensor(True) + if self.display_in_training and is_first_step : - self.test_on_the_fly(fp, data, feed_dict_batch, tb_test_writer) + self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) is_first_step = False if self.timing_in_training : tic = time.time() - # use tensorboard to visualize the training of deepmd-kit - # it will takes some extra execution time to generate the tensorboard data - if self.tensorboard : - summary, _ = self.sess.run([summary_merged_op, self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata) - tb_train_writer.add_summary(summary, cur_batch) - else : - self.sess.run([self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata) + + model_pred = self.model(model_inputs['coord'], model_inputs['type'], model_inputs['natoms_vec'], model_inputs['box'], model_inputs['default_mesh'], model_inputs, suffix = "", reuse = False) + l2_l, l2_more = self.loss.calculate_loss(self.lr_scheduler.get_lr(), model_inputs['natoms_vec'], model_pred, model_inputs, suffix = "test") + + adam.clear_grad() + l2_l.backward() + adam.step() + if self.timing_in_training : toc = time.time() if self.timing_in_training : train_time += toc - tic - cur_batch = self.sess.run(self.global_step) - self.cur_batch = cur_batch + self.cur_batch += 1 + + if (self.cur_batch % self.lr.decay_steps_) == 0: + self.lr_scheduler.step() - if self.display_in_training and (cur_batch % self.disp_freq == 0) : + if self.display_in_training and (self.cur_batch % 
self.disp_freq == 0) : tic = time.time() - self.test_on_the_fly(fp, data, feed_dict_batch, tb_test_writer) + self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) toc = time.time() test_time = toc - tic if self.timing_in_training : log.info("batch %7d training time %.2f s, testing time %.2f s" - % (cur_batch, train_time, test_time)) + % (self.cur_batch, train_time, test_time)) train_time = 0 - if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.run_opt.is_chief : - if self.saver is not None : - self.saver.save (self.sess, os.getcwd() + "/" + self.save_ckpt) - log.info("saved checkpoint %s" % self.save_ckpt) + if self.save_freq > 0 and self.cur_batch % self.save_freq == 0 and self.run_opt.is_chief: + #paddle.jit.save(self.model, os.getcwd() + "/" + self.save_ckpt) + paddle.save(self.model.state_dict(), os.getcwd() + "/" + self.save_ckpt) + log.info("saved checkpoint to %s" % (os.getcwd() + "/" + self.save_ckpt)) if self.run_opt.is_chief: fp.close () if self.profiling and self.run_opt.is_chief : @@ -540,7 +396,7 @@ def train (self, f.write(chrome_trace) def get_global_step (self) : - return self.sess.run(self.global_step) + return self.cur_batch def print_head (self) : if self.run_opt.is_chief: @@ -554,41 +410,71 @@ def print_head (self) : def test_on_the_fly (self, fp, data, - feed_dict_batch, + model_train_inputs, tb_writer) : # Do not need to pass numb_test here as data object already knows it. # Both DeepmdDataSystem and ClassArg parse the same json file + model_test_inputs = {} test_data = data.get_test(n_test=data.get_sys_ntest()) - feed_dict_test = {} + for kk in test_data.keys(): if kk == 'find_type' or kk == 'type' : continue + prec = GLOBAL_PD_FLOAT_PRECISION if 'find_' in kk: - feed_dict_test[self.place_holders[kk]] = test_data[kk] + model_test_inputs[kk] = paddle.to_tensor(test_data[kk], dtype="float32") else: # again the data object knows appropriate test data shape, # there is no need to slice again! 
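# Illustrative sketch, not part of the patch: the trainer above checkpoints with paddle.save(self.model.state_dict(), path). Restoring such a checkpoint would look roughly like the helper below; restore_checkpoint and the default path "model.ckpt" are placeholder names, not code from the repository.
import paddle

def restore_checkpoint(model: paddle.nn.Layer, path: str = "model.ckpt") -> paddle.nn.Layer:
    state_dict = paddle.load(path)     # load the pickled parameter dict written by paddle.save
    model.set_state_dict(state_dict)   # copy the saved tensors into the live Layer
    return model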
# feed_dict_test[self.place_holders[kk]] = np.reshape(test_data[kk][:self.numb_test[data.pick_idx]], [-1]) - feed_dict_test[self.place_holders[kk]] = np.reshape(test_data[kk], [-1]) + model_test_inputs[kk] = paddle.to_tensor(np.reshape(test_data[kk], [-1]), dtype=prec) for ii in ['type'] : - feed_dict_test[self.place_holders[ii]] = np.reshape(test_data[ii], [-1]) + model_test_inputs[ii] = paddle.to_tensor(np.reshape(test_data[ii], [-1]), dtype="int32") for ii in ['natoms_vec', 'default_mesh'] : - feed_dict_test[self.place_holders[ii]] = test_data[ii] - feed_dict_test[self.place_holders['is_training']] = False + model_test_inputs[ii] = paddle.to_tensor(test_data[ii], dtype="int32") - cur_batch = self.cur_batch - current_lr = self.sess.run(self.learning_rate) - if self.run_opt.is_chief: - print_str = "%7d" % cur_batch - print_str += self.loss.print_on_training( - tb_writer, - cur_batch, - self.sess, - test_data['natoms_vec'], - feed_dict_test, - feed_dict_batch - ) + model_test_inputs['is_training'] = paddle.to_tensor(False) + current_batch = self.cur_batch + current_lr = self.lr_scheduler.get_lr() + if self.run_opt.is_chief: + print_str = "%7d" % current_batch + + model_pred = self.model(model_train_inputs['coord'], model_train_inputs['type'], model_train_inputs['natoms_vec'], model_train_inputs['box'], model_train_inputs['default_mesh'], model_train_inputs, suffix = "", reuse = False) + l2_l, l2_more = self.loss.calculate_loss(self.lr_scheduler.get_lr(), model_train_inputs['natoms_vec'], model_pred, model_train_inputs, suffix = "test") + + error_train = l2_l.numpy() + error_e_train = l2_more['l2_ener_loss'].numpy() + error_f_train = l2_more['l2_force_loss'].numpy() + error_v_train = l2_more['l2_virial_loss'].numpy() + error_ae_train = l2_more['l2_atom_ener_loss'].numpy() + error_pf_train = l2_more['l2_pref_force_loss'].numpy() + + model_pred = self.model(model_test_inputs['coord'], model_test_inputs['type'], model_test_inputs['natoms_vec'], model_test_inputs['box'], model_test_inputs['default_mesh'], model_test_inputs, suffix = "", reuse = False) + l2_l, l2_more = self.loss.calculate_loss(self.lr_scheduler.get_lr(), model_test_inputs['natoms_vec'], model_pred, model_test_inputs, suffix = "test") + + error_test = l2_l.numpy() + error_e_test = l2_more['l2_ener_loss'].numpy() + error_f_test = l2_more['l2_force_loss'].numpy() + error_v_test = l2_more['l2_virial_loss'].numpy() + error_ae_test = l2_more['l2_atom_ener_loss'].numpy() + error_pf_test = l2_more['l2_pref_force_loss'].numpy() + + prop_fmt = " %11.2e %11.2e" + natoms = test_data['natoms_vec'] + print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train)) + if self.loss.has_e : + print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0]) + if self.loss.has_ae : + print_str += prop_fmt % (np.sqrt(error_ae_test), np.sqrt(error_ae_train)) + if self.loss.has_f : + print_str += prop_fmt % (np.sqrt(error_f_test), np.sqrt(error_f_train)) + if self.loss.has_v : + print_str += prop_fmt % (np.sqrt(error_v_test) / natoms[0], np.sqrt(error_v_train) / natoms[0]) + if self.loss.has_pf: + print_str += prop_fmt % (np.sqrt(error_pf_test), np.sqrt(error_pf_train)) + + print("batch %7d, lr %f, l2_l %f, l2_ener_loss %f, l2_force_loss %f, l2_virial_loss %f, l2_atom_ener_loss %f, l2_pref_force_loss %f" % (current_batch, current_lr, error_train, error_e_train, error_f_train, error_v_train, error_ae_train, error_pf_train)) print_str += " %8.1e\n" % current_lr fp.write(print_str) - fp.flush () + fp.flush () \ No 
newline at end of file diff --git a/deepmd/utils/learning_rate.py b/deepmd/utils/learning_rate.py index 572f317a92..a10fe06f2b 100644 --- a/deepmd/utils/learning_rate.py +++ b/deepmd/utils/learning_rate.py @@ -1,5 +1,5 @@ import numpy as np -from deepmd.env import tf +from deepmd.env import paddle from deepmd.common import ClassArg class LearningRateExp (object) : @@ -44,10 +44,7 @@ def __init__ (self, self.cd['decay_rate'] = decay_rate self.start_lr_ = self.cd['start_lr'] - def build(self, - global_step : tf.Tensor, - stop_step : int = None - ) -> tf.Tensor : + def build(self, stop_step : int = None): """ Build the learning rate @@ -73,12 +70,10 @@ def build(self, if self.decay_steps_ >= stop_step: self.decay_steps_ = default_ds self.decay_rate_ = np.exp(np.log(self.stop_lr_ / self.start_lr_) / (stop_step / self.decay_steps_)) - - return tf.train.exponential_decay(self.start_lr_, - global_step, - self.decay_steps_, - self.decay_rate_, - staircase=True) + + return paddle.optimizer.lr.ExponentialDecay(learning_rate=self.start_lr_, gamma=self.decay_rate_, verbose=True) + + def start_lr(self) -> float: """ Get the start lr diff --git a/deepmd/utils/network.py b/deepmd/utils/network.py index ab32308bdd..4dc99e3ce6 100644 --- a/deepmd/utils/network.py +++ b/deepmd/utils/network.py @@ -1,7 +1,8 @@ import numpy as np -from deepmd.env import tf -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION +from deepmd.env import tf, paddle +from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION + def one_layer(inputs, outputs_size, @@ -91,7 +92,7 @@ def embedding_net(xx, If the netowk is trainable """ outputs_size = [1] + network_size - + for ii in range(1, len(outputs_size)): w = tf.get_variable('matrix_'+str(ii)+name_suffix, [outputs_size[ii - 1], outputs_size[ii]], @@ -131,6 +132,8 @@ def embedding_net(xx, return xx + + def variable_summaries(var: tf.Variable, name: str): """Attach a lot of summaries to a Tensor (for TensorBoard visualization). 
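# Illustrative sketch, not part of the patch: the learning_rate.py hunk above replaces tf.train.exponential_decay(..., staircase=True) with paddle.optimizer.lr.ExponentialDecay, and the trainer only calls step() once every decay_steps_ batches. Stepping once per decay interval reproduces the staircase schedule lr(t) = start_lr * decay_rate ** (t // decay_steps); the numbers below are arbitrary examples.
import numpy as np
import paddle

start_lr, stop_lr, stop_step, decay_steps = 1.0e-3, 1.0e-8, 1000, 100
decay_rate = float(np.exp(np.log(stop_lr / start_lr) / (stop_step / decay_steps)))

sched = paddle.optimizer.lr.ExponentialDecay(learning_rate=start_lr, gamma=decay_rate)
for batch in range(1, stop_step + 1):
    if batch % decay_steps == 0:   # mirrors the trainer's `cur_batch % decay_steps_ == 0` check
        sched.step()
assert np.isclose(sched.get_lr(), start_lr * decay_rate ** (stop_step // decay_steps))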
@@ -150,4 +153,136 @@ def variable_summaries(var: tf.Variable, name: str): tf.summary.scalar('stddev', stddev) tf.summary.scalar('max', tf.reduce_max(var)) tf.summary.scalar('min', tf.reduce_min(var)) - tf.summary.histogram('histogram', var) \ No newline at end of file + tf.summary.histogram('histogram', var) + + +class OneLayer(paddle.nn.Layer): + def __init__(self, + in_features, + out_features, + activation_fn=paddle.nn.functional.tanh, + precision = GLOBAL_PD_FLOAT_PRECISION, + stddev=1.0, + bavg=0.0, + name='linear', + seed=None, + use_timestep = False, + trainable = True, + useBN = False): + super(OneLayer, self).__init__(name) + self.out_features = out_features + self.activation_fn = activation_fn + self.use_timestep = use_timestep + self.useBN = useBN + self.seed = seed + if seed is not None: paddle.seed(seed) + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype = precision, + is_bias= False, + default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(in_features+out_features))) + self.bias = self.create_parameter( + shape=[out_features], + dtype = precision, + is_bias=True, + default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev)) + if self.activation_fn != None and self.use_timestep : + self.idt = self.create_parameter( + shape=[out_features], + dtype=precision, + default_initializer = paddle.nn.initializer.Normal(mean = 0.1, std = 0.001)) + + + def forward(self, input): + hidden = paddle.fluid.layers.matmul(input, self.weight) + self.bias + if self.activation_fn != None: + if self.useBN: + None + # hidden_bn = self._batch_norm(hidden, name=name+'_normalization', reuse=reuse) + # return activation_fn(hidden_bn) + else: + if self.use_timestep : + out = paddle.reshape(self.activation_fn(hidden), [-1, self.out_features]) * self.idt + else : + out = paddle.reshape(self.activation_fn(hidden), [-1, self.out_features]) + else: + if self.useBN: + None + # return self._batch_norm(hidden, name=name+'_normalization', reuse=reuse) + else: + out = hidden + return out + + + +class EmbeddingNet(paddle.nn.Layer): + """ + Parameters + ---------- + xx : Tensor + Input tensor of shape [-1,1] + network_size: list of int + Size of the embedding network. For example [16,32,64] + precision: + Precision of network weights. For example, np.float64 + activation_fn: + Activation function + resnet_dt: boolean + Using time-step in the ResNet construction + name_suffix: str + The name suffix appended to each variable.
+ stddev: float + Standard deviation of initializing network parameters + bavg: float + Mean of network initial bias + seed: int + Random seed for initializing network parameters + trainable: boolean + If the network is trainable + """ + def __init__(self, + network_size, + precision, + activation_fn = paddle.nn.functional.tanh, + resnet_dt = False, + seed = None, + trainable = True, + stddev = 1.0, + bavg = 0.0, + name=''): + super(EmbeddingNet, self).__init__(name) + self.outputs_size = [1] + network_size + self.activation_fn = activation_fn + self.seed = seed + if seed is not None: paddle.seed(seed) + + outputs_size = self.outputs_size + weight = [] + bias = [] + for ii in range(1, len(outputs_size)): + weight.append(self.create_parameter( + shape = [outputs_size[ii-1], outputs_size[ii]], + dtype = precision, + is_bias= False, + default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1])))) + bias.append(self.create_parameter( + shape = [1, outputs_size[ii]], + dtype = precision, + is_bias= True, + default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev))) + + self.weight = paddle.nn.ParameterList(weight) + self.bias = paddle.nn.ParameterList(bias) + + + def forward(self, xx): + outputs_size = self.outputs_size + for ii in range(1, len(outputs_size)): + hidden = paddle.reshape(self.activation_fn(paddle.fluid.layers.matmul(xx, self.weight[ii-1]) + self.bias[ii-1]), [-1, outputs_size[ii]]) + if outputs_size[ii] == outputs_size[ii-1] * 2: + xx = paddle.concat([xx,xx], axis=1) + hidden + else: + xx = hidden + + return xx
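# Usage sketch, not part of the patch: assuming the EmbeddingNet defined above, it maps per-neighbor scalars of shape [N, 1] to an [N, network_size[-1]] feature, adding the duplicated-input skip connection whenever a layer doubles the previous width. The sizes below are arbitrary examples.
import paddle

net = EmbeddingNet(network_size=[16, 32, 64], precision="float32", seed=1)
s_ij = paddle.rand([128, 1], dtype="float32")   # e.g. one switched inverse distance per neighbor pair
features = net(s_ij)
print(features.shape)   # [128, 64]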