From 13b8e6f45acdc3ce97d9876451879890835f6f80 Mon Sep 17 00:00:00 2001 From: zhouwei25 Date: Thu, 15 Apr 2021 15:15:15 +0000 Subject: [PATCH] Add Ener Model for Paddle --- deepmd/common.py | 15 ++++------- deepmd/descriptor/se_a.py | 55 +++++++++------------------------------ deepmd/fit/ener.py | 41 ++++++++++++++--------------- deepmd/loss/ener.py | 9 ++++--- deepmd/model/ener.py | 14 +++++----- deepmd/train/trainer.py | 49 ++++++++-------------------------- deepmd/utils/network.py | 33 +++++++---------------- 7 files changed, 70 insertions(+), 146 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index 3e84affa16..2a48c143b1 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -68,11 +68,11 @@ def gelu(x: tf.Tensor) -> tf.Tensor: data_requirement = {} ACTIVATION_FN_DICT = { - "relu": tf.nn.relu, - "relu6": tf.nn.relu6, - "softplus": tf.nn.softplus, - "sigmoid": tf.sigmoid, - "tanh": tf.tanh, + "relu": paddle.nn.functional.relu, + "relu6": paddle.nn.functional.relu6, + "softplus": paddle.nn.functional.softplus, + "sigmoid": paddle.nn.functional.sigmoid, + "tanh": paddle.nn.functional.tanh, "gelu": gelu, } @@ -385,11 +385,6 @@ def get_activation_func( RuntimeError if unknown activation function is specified """ - #return paddle.nn.functional.tanh - def fun(x): - return paddle.clip(x, min=-1.0, max=1.0) - return fun - if activation_fn not in ACTIVATION_FN_DICT: raise RuntimeError(f"{activation_fn} is not a valid activation function") return ACTIVATION_FN_DICT[activation_fn] diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 8ba9439722..1553ad8814 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -14,7 +14,7 @@ import sys -class DescrptSeA (paddle.nn.Layer): +class DescrptSeA(paddle.nn.Layer): @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys())) def __init__ (self, rcut: float, @@ -308,12 +308,6 @@ def forward (self, box = paddle.reshape(box_, [-1, 9]) atype = paddle.reshape(atype_, [-1, natoms[1]]) - #print("coord= ", coord.shape) - #print("box= ", box.shape) - #print("atype= ", atype.shape) - #print("natoms= ", natoms.shape) - #print("mesh= ", mesh.shape) - self.descrpt, self.descrpt_deriv, self.rij, self.nlist \ = paddle_ops.prod_env_mat_a(coord, atype, @@ -328,16 +322,6 @@ def forward (self, sel_a = self.sel_a, sel_r = self.sel_r) - #self.descrpt = to_tensor(np.load('/workspace/deepmd-kit/examples/water/train/descrpt.npy'), stop_gradient=False) - #self.descrpt_deriv = to_tensor(np.load('/workspace/deepmd-kit/examples/water/train/descrpt_deriv.npy')) - #self.rij = to_tensor(np.load('/workspace/deepmd-kit/examples/water/train/rij.npy')) - #self.nlist = to_tensor(np.load('/workspace/deepmd-kit/examples/water/train/nlist.npy')) - - #print("self.descrpt= ", self.descrpt) - #print("self.descrpt_deriv= ", self.descrpt_deriv) - #print("self.rij= ", self.rij) - #print("self.nlist= ", self.nlist) - self.descrpt_reshape = paddle.reshape(self.descrpt, [-1, self.ndescrpt]) self.descrpt_reshape.stop_gradient = False @@ -386,10 +370,6 @@ def prod_force_virial(self, net_deriv = paddle.grad(atom_ener, self.descrpt_reshape, create_graph=True)[0] net_deriv_reshape = paddle.reshape (net_deriv, [-1, natoms[0] * self.ndescrpt]) - self.net_deriv_reshape = net_deriv_reshape - - paddle.set_device("cpu") - force \ = paddle_ops.prod_force_se_a (net_deriv_reshape, self.descrpt_deriv, @@ -397,6 +377,7 @@ def prod_force_virial(self, natoms, n_a_sel = self.nnei_a, n_r_sel = self.nnei_r) + virial, atom_virial \ = 
paddle_ops.prod_virial_se_a (net_deriv_reshape, self.descrpt_deriv, @@ -405,8 +386,7 @@ def prod_force_virial(self, natoms, n_a_sel = self.nnei_a, n_r_sel = self.nnei_r) - - paddle.set_device("gpu") + return force, virial, atom_virial @@ -445,14 +425,6 @@ def _compute_dstats_sys_smth (self, data_atype, natoms_vec, mesh) : - - #print("pbefore sub_sess run========") - #print("data_coord= ", data_coord) - #print("data_atype= ", data_atype) - #print("natoms_vec= ", natoms_vec) - #print("data_box= ", data_box) - #print("mesh= ", mesh) - input_dict = {} input_dict['coord'] = paddle.to_tensor(data_coord, dtype=GLOBAL_NP_FLOAT_PRECISION) input_dict['box'] = paddle.to_tensor(data_box, dtype=GLOBAL_PD_FLOAT_PRECISION) @@ -472,9 +444,7 @@ def _compute_dstats_sys_smth (self, rcut_r_smth = self.rcut_r_smth, sel_a = self.sel_a, sel_r = self.sel_r) - - #print("self.stat_descrpt ", stat_descrpt) - #print("==========after sub_sess run=========") + dd_all = self.stat_descrpt.numpy() natoms = natoms_vec dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]]) @@ -513,14 +483,14 @@ def _compute_std (self,sumv2, sumv, sumn) : return val def _filter(self, - inputs, - type_input, - natoms, - activation_fn=paddle.nn.functional.relu, - stddev=1.0, - bavg=0.0, - reuse=None, - seed=None, + inputs, + type_input, + natoms, + activation_fn=paddle.nn.functional.tanh, + stddev=1.0, + bavg=0.0, + reuse=None, + seed=None, trainable = True): # natom x (nei x 4) shape = inputs.shape @@ -577,3 +547,4 @@ def _filter(self, result = paddle.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) return result, qmat + diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 471f4684d9..d91f75c21c 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -84,16 +84,16 @@ def __init__ (self, # stat fparam if self.numb_fparam > 0: - self.t_fparam_avg = paddl.to_tensor(np.zeros([1, self.numb_fparam]), + self.t_fparam_avg = paddle.to_tensor(np.zeros([1, self.numb_fparam]), dtype = GLOBAL_PD_FLOAT_PRECISION) - self.t_fparam_istd = paddl.to_tensor(np.ones([1, self.numb_fparam]), + self.t_fparam_istd = paddle.to_tensor(np.ones([1, self.numb_fparam]), dtype = GLOBAL_PD_FLOAT_PRECISION) # stat aparam if self.numb_aparam > 0: - self.t_aparam_avg = paddl.to_tensor(np.zeros([1, self.numb_aparam]), + self.t_aparam_avg = paddle.to_tensor(np.zeros([1, self.numb_aparam]), dtype = GLOBAL_PD_FLOAT_PRECISION) - self.t_aparam_istd = tf.get_variable(np.ones([1, self.numb_aparam]), + self.t_aparam_istd = paddle.to_tensor(np.ones([1, self.numb_aparam]), dtype = GLOBAL_PD_FLOAT_PRECISION) @@ -123,6 +123,15 @@ def compute_output_stats(self, can be prepared by model.make_stat_input """ self.bias_atom_e = self._compute_output_stats(all_stat, rcond = self.rcond) + if self.bias_atom_e is not None: + assert (len(self.bias_atom_e) == self.ntypes) + for type_i in range(self.ntypes): + type_bias_ae = self.bias_atom_e[type_i] + paddle.seed(self.seed) + normal_init_ = paddle.nn.initializer.Normal(mean=type_bias_ae, std=1.0) + final_layer = self.ElementNets[type_i][-1] + normal_init_(final_layer.bias) + @classmethod def _compute_output_stats(self, all_stat, rcond = 1e-3): @@ -173,9 +182,9 @@ def compute_input_stats(self, self.fparam_std[ii] = protection self.fparam_inv_std = 1./self.fparam_std - self.t_fparam_avg = paddl.to_tensor(self.fparam_avg, + self.t_fparam_avg = paddle.to_tensor(self.fparam_avg, dtype = GLOBAL_PD_FLOAT_PRECISION) - self.t_fparam_istd = paddl.to_tensor(self.fparam_inv_std, + self.t_fparam_istd = paddle.to_tensor(self.fparam_inv_std, 
dtype = GLOBAL_PD_FLOAT_PRECISION) # stat aparam @@ -198,9 +207,9 @@ def compute_input_stats(self, self.aparam_std[ii] = protection self.aparam_inv_std = 1./self.aparam_std - self.t_aparam_avg = paddl.to_tensor(self.aparam_avg, + self.t_aparam_avg = paddle.to_tensor(self.aparam_avg, dtype = GLOBAL_PD_FLOAT_PRECISION) - self.t_aparam_istd = tf.get_variable(self.aparam_inv_std, + self.t_aparam_istd = paddle.to_tensor(self.aparam_inv_std, dtype = GLOBAL_PD_FLOAT_PRECISION) @@ -209,7 +218,6 @@ def _compute_std (self, sumv2, sumv, sumn) : def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): - bias_atom_e = self.bias_atom_e if self.numb_fparam > 0 and (self.fparam_avg is None or self.fparam_inv_std is None): raise RuntimeError('No data stat result. one should do data statisitic, before build') if self.numb_aparam > 0 and (self.aparam_avg is None or self.aparam_inv_std is None): @@ -218,9 +226,6 @@ def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): start_index = 0 inputs = paddle.cast(paddle.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) - if bias_atom_e is not None: - assert (len(bias_atom_e) == self.ntypes) - if self.numb_fparam > 0: fparam = input_dict['fparam'] fparam = paddle.reshape(fparam, [-1, self.numb_fparam]) @@ -252,10 +257,6 @@ def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): layer = paddle.concat([layer, ext_aparam], axis=1) start_index += natoms[2 + type_i] - if bias_atom_e is None: - type_bias_ae = 0.0 - else: - type_bias_ae = bias_atom_e[type_i] for ii in range(0, len(self.n_neuron)) : if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] : @@ -264,11 +265,7 @@ def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): layer = self.ElementNets[type_i][ii](layer) final_layer = self.ElementNets[type_i][len(self.n_neuron)](layer) - if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: - zero_inputs = paddle.cast(layer, self.fitting_precision) - zero_inputs[:, :self.dim_descrpt] = 0. 
- zero_layer = net_i(zero_inputs) - final_layer += self.atom_ener[type_i] - zero_layer + # if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: (Not implement) final_layer = paddle.reshape(final_layer, [inputs.shape[0], natoms[2 + type_i]]) @@ -278,4 +275,4 @@ def forward(self, inputs, natoms, input_dict, reuse=None, suffix=''): else: outs = paddle.concat([outs, final_layer], axis=1) - return paddle.cast(paddle.reshape(outs, [-1]), GLOBAL_PD_FLOAT_PRECISION) \ No newline at end of file + return paddle.cast(paddle.reshape(outs, [-1]), GLOBAL_PD_FLOAT_PRECISION) diff --git a/deepmd/loss/ener.py b/deepmd/loss/ener.py index a1eca0405a..e5bea2f028 100644 --- a/deepmd/loss/ener.py +++ b/deepmd/loss/ener.py @@ -84,17 +84,17 @@ def calculate_loss (self, l2_force_loss = paddle.mean(paddle.square(diff_f), name = "l2_force_" + suffix) l2_pref_force_loss = paddle.mean(paddle.multiply(paddle.square(diff_f), atom_pref_reshape), name = "l2_pref_force_" + suffix) - virial_reshape = paddle.reshape (virial, [-1]) + virial_reshape = paddle.reshape(virial, [-1]) virial_hat_reshape = paddle.reshape (virial_hat, [-1]) - l2_virial_loss = paddle.mean (paddle.square(virial_hat_reshape - virial_reshape), name = "l2_virial_" + suffix) + l2_virial_loss = paddle.mean(paddle.square(virial_hat_reshape - virial_reshape), name = "l2_virial_" + suffix) atom_ener_reshape = paddle.reshape (atom_ener, [-1]) atom_ener_hat_reshape = paddle.reshape (atom_ener_hat, [-1]) l2_atom_ener_loss = paddle.mean (paddle.square(atom_ener_hat_reshape - atom_ener_reshape), name = "l2_atom_ener_" + suffix) atom_norm = 1./ global_cvt_2_pd_float(natoms[0]) - atom_norm_ener = 1./ global_cvt_2_pd_float(natoms[0]) - pref_e = global_cvt_2_pd_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate) ) + atom_norm_ener = 1./ global_cvt_2_ener_float(natoms[0]) + pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate)) pref_f = global_cvt_2_pd_float(find_force * (self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * learning_rate / self.starter_learning_rate) ) pref_v = global_cvt_2_pd_float(find_virial * (self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * learning_rate / self.starter_learning_rate) ) pref_ae= global_cvt_2_pd_float(find_atom_ener * (self.limit_pref_ae+ (self.start_pref_ae-self.limit_pref_ae) * learning_rate / self.starter_learning_rate) ) @@ -136,6 +136,7 @@ def print_header(self): print_str += prop_fmt % ('rmse_v_tst', 'rmse_v_trn') if self.has_pf : print_str += prop_fmt % ('rmse_pf_tst', 'rmse_pf_trn') + return print_str diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index 38cebc7a4a..b0335a0aeb 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -10,6 +10,7 @@ import sys + class EnerModel(paddle.nn.Layer) : model_type = 'ener' @@ -130,7 +131,7 @@ def forward (self, reuse = reuse) self.dout = dout - + atom_ener = self.fitting (dout, natoms, input_dict, @@ -143,11 +144,11 @@ def forward (self, energy_raw = paddle.reshape(energy_raw, [-1, natoms[0]], name = 'o_atom_energy'+suffix) energy = paddle.sum(paddle.cast(energy_raw, GLOBAL_ENER_FLOAT_PRECISION), axis=1, name='o_energy'+suffix) - force, virial, atom_virial = self.descrpt.prod_force_virial (atom_ener, natoms) - - force = paddle.reshape (force, [-1, 3 * natoms[1]], name = "o_force"+suffix) - virial = paddle.reshape (virial, [-1, 9], name = "o_virial"+suffix) 
- atom_virial = paddle.reshape (atom_virial, [-1, 9 * natoms[1]], name = "o_atom_virial"+suffix) + force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms) + + force = paddle.reshape(force, [-1, 3 * natoms[1]], name = "o_force"+suffix) + virial = paddle.reshape(virial, [-1, 9], name = "o_virial"+suffix) + atom_virial = paddle.reshape(atom_virial, [-1, 9 * natoms[1]], name = "o_atom_virial"+suffix) model_dict = {} model_dict['energy'] = energy @@ -159,4 +160,3 @@ def forward (self, model_dict['atype'] = atype return model_dict - diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 7d784ab2a0..e8fb2cee20 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -238,7 +238,6 @@ def _init_param(self, jdata): else : raise RuntimeError('get unknown fitting type when building loss function') - print(self.model) # training training_param = j_must_have(jdata, 'training') @@ -305,6 +304,7 @@ def build (self, def train (self, data, stop_batch) : + paddle.set_device("gpu") self.stop_batch = stop_batch self.print_head() @@ -353,11 +353,10 @@ def train (self, model_inputs[ii] = paddle.to_tensor(np.reshape(batch_data[ii], [-1]), dtype="int32") for ii in ['natoms_vec', 'default_mesh'] : model_inputs[ii] = paddle.to_tensor(batch_data[ii], dtype="int32") - model_inputs['is_training'] = paddle.to_tensor(True) - + if self.display_in_training and is_first_step : - #self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) + self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) is_first_step = False if self.timing_in_training : tic = time.time() @@ -366,44 +365,18 @@ def train (self, adam.clear_grad() l2_l.backward() - - #print([[name, p._grad_ivar() is None] for name, p in self.model.named_parameters()]) - #print("\n ", [p for p, g in adam.backward(l2_l)]) - #print("\n ", [g for p, g in adam.backward(l2_l)]) - - #print(self.model.descrpt.dout.grad) - - #print(l2_l.grad) - #print(self.model.descrpt.embedding_nets[0].weight[0].grad) - #print(self.model.descrpt.embedding_nets[0].bias[0].grad) - #print(self.model.descrpt.embedding_nets[0].weight[1].grad) - #print(self.model.descrpt.embedding_nets[0].bias[1].grad) - #print(self.model.descrpt.embedding_nets[0].weight[2].grad) - #print(self.model.descrpt.embedding_nets[0].bias[2].grad) - adam.step() - #print("self.atom_ener= ", self.model.atom_ener) - #print("self.dout= ", self.model.dout) - #print("self.net_deriv_reshape= ", self.descrpt.net_deriv_reshape) - - if self.timing_in_training : toc = time.time() if self.timing_in_training : train_time += toc - tic - self.cur_batch += 1 - if self.cur_batch == 1: - exit(0) - - print("batch %7d training time %.2f s, l2_l %f" % (self.cur_batch, train_time, l2_l.numpy())) - if (self.cur_batch % self.lr.decay_steps_) == 0: self.lr_scheduler.step() if self.display_in_training and (self.cur_batch % self.disp_freq == 0) : tic = time.time() - #self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) + self.test_on_the_fly(fp, data, model_inputs, tb_test_writer) toc = time.time() test_time = toc - tic if self.timing_in_training : @@ -487,21 +460,21 @@ def test_on_the_fly (self, error_ae_test = l2_more['l2_atom_ener_loss'].numpy() error_pf_test = l2_more['l2_pref_force_loss'].numpy() - print_str = "" prop_fmt = " %11.2e %11.2e" + natoms = test_data['natoms_vec'] print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train)) - if self.has_e : + if self.loss.has_e : print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0]) - if 
self.has_ae : + if self.loss.has_ae : print_str += prop_fmt % (np.sqrt(error_ae_test), np.sqrt(error_ae_train)) - if self.has_f : + if self.loss.has_f : print_str += prop_fmt % (np.sqrt(error_f_test), np.sqrt(error_f_train)) - if self.has_v : + if self.loss.has_v : print_str += prop_fmt % (np.sqrt(error_v_test) / natoms[0], np.sqrt(error_v_train) / natoms[0]) - if self.has_pf: + if self.loss.has_pf: print_str += prop_fmt % (np.sqrt(error_pf_test), np.sqrt(error_pf_train)) + print("batch %7d, lr %f, l2_l %f, l2_ener_loss %f, l2_force_loss %f, l2_virial_loss %f, l2_atom_ener_loss %f, l2_pref_force_loss %f" % (current_batch, current_lr, error_train, error_e_train, error_f_train, error_v_train, error_ae_train, error_pf_train)) print_str += " %8.1e\n" % current_lr - print(print_str) fp.write(print_str) fp.flush () \ No newline at end of file diff --git a/deepmd/utils/network.py b/deepmd/utils/network.py index f9ba216ac8..4dc99e3ce6 100644 --- a/deepmd/utils/network.py +++ b/deepmd/utils/network.py @@ -3,11 +3,6 @@ from deepmd.env import tf, paddle from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION -w1 = 0.001 -b1 = -0.05 - -w2 = -0.002 -b2 = 0.03 def one_layer(inputs, outputs_size, @@ -99,7 +94,6 @@ def embedding_net(xx, outputs_size = [1] + network_size for ii in range(1, len(outputs_size)): - print("for ii in range(1, len(outputs_size)):") w = tf.get_variable('matrix_'+str(ii)+name_suffix, [outputs_size[ii - 1], outputs_size[ii]], precision, @@ -116,7 +110,6 @@ def embedding_net(xx, hidden = tf.reshape(activation_fn(tf.matmul(xx, w) + b), [-1, outputs_size[ii]]) if resnet_dt : - print("resnet_dt") idt = tf.get_variable('idt_'+str(ii)+name_suffix, [1, outputs_size[ii]], precision, @@ -125,7 +118,6 @@ def embedding_net(xx, variable_summaries(idt, 'idt_'+str(ii)+name_suffix) if outputs_size[ii] == outputs_size[ii-1]: - print("outputs_size[ii] == outputs_size[ii-1]") if resnet_dt : xx += hidden * idt else : @@ -168,7 +160,7 @@ class OneLayer(paddle.nn.Layer): def __init__(self, in_features, out_features, - activation_fn=paddle.nn.functional.relu, + activation_fn=paddle.nn.functional.tanh, precision = GLOBAL_PD_FLOAT_PRECISION, stddev=1.0, bavg=0.0, @@ -184,25 +176,23 @@ def __init__(self, self.useBN = useBN self.seed = seed paddle.seed(seed) - + self.weight = self.create_parameter( shape=[in_features, out_features], dtype = precision, is_bias= False, - default_initializer = paddle.fluid.initializer.Constant(w1)) - #default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(in_features+out_features))) + default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(in_features+out_features))) self.bias = self.create_parameter( shape=[out_features], dtype = precision, is_bias=True, - default_initializer = paddle.fluid.initializer.Constant(b1)) - #default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev)) + default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev)) if self.activation_fn != None and self.use_timestep : self.idt = self.create_parameter( shape=[out_features], dtype=precision, - default_initializer = paddle.fluid.initializer.Constant(b1)) - #default_initializer = paddle.nn.initializer.Normal(mean = 0.1, std = 0.001)) + default_initializer = paddle.nn.initializer.Normal(mean = 0.1, std = 0.001)) + def forward(self, input): hidden = paddle.fluid.layers.matmul(input, self.weight) + self.bias @@ -254,7 +244,7 @@ class EmbeddingNet(paddle.nn.Layer): def __init__(self, network_size, precision, - 
activation_fn = paddle.nn.functional.relu, + activation_fn = paddle.nn.functional.tanh, resnet_dt = False, seed = None, trainable = True, @@ -275,20 +265,17 @@ def __init__(self, shape = [outputs_size[ii-1], outputs_size[ii]], dtype = precision, is_bias= False, - default_initializer = paddle.fluid.initializer.Constant(w2))) - #default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1])))) - + default_initializer = paddle.nn.initializer.Normal(std = stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1])))) bias.append(self.create_parameter( shape = [1, outputs_size[ii]], dtype = precision, is_bias= True, - default_initializer = paddle.fluid.initializer.Constant(b2))) - #default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev))) + default_initializer = paddle.nn.initializer.Normal(mean = bavg, std = stddev))) self.weight = paddle.nn.ParameterList(weight) self.bias = paddle.nn.ParameterList(bias) - + def forward(self, xx): outputs_size = self.outputs_size for ii in range(1, len(outputs_size)):
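
A minimal sketch of the patched activation lookup in deepmd/common.py, assuming Paddle 2.x and omitting the custom gelu entry for brevity; this is illustrative only and not part of the diff above. The patch swaps the tf.* entries in ACTIVATION_FN_DICT for their paddle.nn.functional counterparts and drops the temporary clip()-based override, so get_activation_func() once again resolves a name to a real Paddle activation:

import paddle

ACTIVATION_FN_DICT = {
    "relu": paddle.nn.functional.relu,
    "relu6": paddle.nn.functional.relu6,
    "softplus": paddle.nn.functional.softplus,
    "sigmoid": paddle.nn.functional.sigmoid,
    "tanh": paddle.nn.functional.tanh,
}

def get_activation_func(activation_fn: str):
    # Resolve a string name to a Paddle activation, mirroring the patched common.py:
    # fail fast on unknown names, otherwise return the callable from the table.
    if activation_fn not in ACTIVATION_FN_DICT:
        raise RuntimeError(f"{activation_fn} is not a valid activation function")
    return ACTIVATION_FN_DICT[activation_fn]

# Example: "tanh" is also the default the patch switches to in _filter, OneLayer
# and EmbeddingNet (previously relu / constant-initialized debug settings).
act = get_activation_func("tanh")
print(act(paddle.to_tensor([-1.0, 0.0, 1.0])).numpy())

The descriptor and fitting networks rely on the same look-up-by-name pattern when their defaults move from relu to tanh elsewhere in this patch.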