diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index e1205d8593..58f1139ffb 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -38,8 +38,6 @@ ) -# @Descriptor.register("se_e2_a") -# @Descriptor.register("se_a") class DescrptSeA(paddle.nn.Layer): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. @@ -149,14 +147,14 @@ def __init__( raise RuntimeError( f"rcut_smth ({rcut_smth:f}) should be no more than rcut ({rcut:f})!" ) - self.sel_a = sel # [46(O), 92(H)] - self.rcut_r = rcut # 6.0 + self.sel_a = sel + self.rcut_r = rcut # NOTE: register 'rcut' in buffer to be accessed in inference self.register_buffer("buffer_rcut", paddle.to_tensor(rcut, dtype="float64")) - self.rcut_r_smth = rcut_smth # 0.5 - self.filter_neuron = neuron # [25, 50, 100] - self.n_axis_neuron = axis_neuron # 16 - self.filter_resnet_dt = resnet_dt # False + self.rcut_r_smth = rcut_smth + self.filter_neuron = neuron + self.n_axis_neuron = axis_neuron + self.filter_resnet_dt = resnet_dt self.seed = seed self.uniform_seed = uniform_seed self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron) @@ -164,13 +162,13 @@ def __init__( self.compress_activation_fn = get_activation_func(activation_function) self.filter_activation_fn = get_activation_func(activation_function) self.filter_precision = get_precision(precision) - self.exclude_types = set() # empty + self.exclude_types = set() for tt in exclude_types: assert len(tt) == 2 self.exclude_types.add((tt[0], tt[1])) self.exclude_types.add((tt[1], tt[0])) - self.set_davg_zero = set_davg_zero # False - # self.type_one_side = type_one_side # False + self.set_davg_zero = set_davg_zero + # self.type_one_side = type_one_side self.type_one_side = False self.spin = spin # None @@ -187,8 +185,8 @@ def __init__( ) # descrpt config - self.sel_r = [0 for ii in range(len(self.sel_a))] # [0, 0] - self.ntypes = len(self.sel_a) # 2 + self.sel_r = [0 for ii in range(len(self.sel_a))] + self.ntypes = len(self.sel_a) # NOTE: register 'ntypes' in buffer to be accessed in inference self.register_buffer( "buffer_ntypes", paddle.to_tensor(self.ntypes, dtype="int32") @@ -196,37 +194,19 @@ def __init__( assert self.ntypes == len(self.sel_r) self.rcut_a = -1 # numb of neighbors and numb of descrptors - self.nnei_a = np.cumsum(self.sel_a)[-1] # 138 邻域内原子个数 - self.nnei_r = np.cumsum(self.sel_r)[-1] # 0 - self.nnei = self.nnei_a + self.nnei_r # 138 - self.ndescrpt_a = self.nnei_a * 4 # 552 原子个数*4([s, s/x, s/y, s/z]) - self.ndescrpt_r = self.nnei_r * 1 # 0 - self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r # 552 + self.nnei_a = np.cumsum(self.sel_a)[-1] + self.nnei_r = np.cumsum(self.sel_r)[-1] + self.nnei = self.nnei_a + self.nnei_r + self.ndescrpt_a = self.nnei_a * 4 + self.ndescrpt_r = self.nnei_r * 1 + self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r self.useBN = False self.dstd = None self.davg = None - # self.compress = False - # self.embedding_net_variables = None - # self.mixed_prec = None - # self.place_holders = {} - # self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a) - """ - array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1]) - """ - self.avg_zero = paddle.zeros( - [self.ntypes, self.ndescrpt], dtype="float32" - ) # [2, 552] - self.std_ones = paddle.ones( - [self.ntypes, self.ndescrpt], dtype="float32" - ) # [2, 552] + self.avg_zero = paddle.zeros([self.ntypes, self.ndescrpt], dtype="float32") + self.std_ones = paddle.ones([self.ntypes, self.ndescrpt], dtype="float32") + nets = [] - # self._pass_filter => self._filter => self._filter_lower for type_input in range(self.ntypes): layer = [] for type_i in range(self.ntypes): @@ -248,46 +228,8 @@ def __init__( self.compress = False self.embedding_net_variables = None self.mixed_prec = None - # self.place_holders = {} self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a) # like a mask - # avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype( - # GLOBAL_NP_FLOAT_PRECISION - # ) - # std_ones = np.ones([self.ntypes, self.ndescrpt]).astype( - # GLOBAL_NP_FLOAT_PRECISION - # ) - # sub_graph = tf.Graph() - # with sub_graph.as_default(): - # name_pfx = "d_sea_" - # for ii in ["coord", "box"]: - # self.place_holders[ii] = tf.placeholder( - # GLOBAL_NP_FLOAT_PRECISION, [None, None], name=name_pfx + "t_" + ii - # ) - # self.place_holders["type"] = tf.placeholder( - # tf.int32, [None, None], name=name_pfx + "t_type" - # ) - # self.place_holders["natoms_vec"] = tf.placeholder( - # tf.int32, [self.ntypes + 2], name=name_pfx + "t_natoms" - # ) - # self.place_holders["default_mesh"] = tf.placeholder( - # tf.int32, [None], name=name_pfx + "t_mesh" - # ) - # self.stat_descrpt, descrpt_deriv, rij, nlist = op_module.prod_env_mat_a( - # self.place_holders["coord"], - # self.place_holders["type"], - # self.place_holders["natoms_vec"], - # self.place_holders["box"], - # self.place_holders["default_mesh"], - # self.avg_zero, - # self.std_ones, - # rcut_a=self.rcut_a, - # rcut_r=self.rcut_r, - # rcut_r_smth=self.rcut_r_smth, - # sel_a=self.sel_a, - # sel_r=self.sel_r, - # ) - # self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None self.multi_task = multi_task if multi_task: @@ -589,47 +531,14 @@ def forward( """ davg = self.davg dstd = self.dstd - # if nvnmd_cfg.enable: - # if nvnmd_cfg.restore_descriptor: - # davg, dstd = build_davg_dstd() - # check_switch_range(davg, dstd) - # with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): if davg is None: davg = np.zeros([self.ntypes, self.ndescrpt]) if dstd is None: dstd = np.ones([self.ntypes, self.ndescrpt]) - # t_rcut = tf.constant( - # np.max([self.rcut_r, self.rcut_a]), - # name="rcut", - # dtype=GLOBAL_TF_FLOAT_PRECISION, - # ) - # t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32) - # t_ndescrpt = tf.constant(self.ndescrpt, name="ndescrpt", dtype=tf.int32) - # t_sel = tf.constant(self.sel_a, name="sel", dtype=tf.int32) - # t_original_sel = paddle.to_tensor( - # self.original_sel if self.original_sel is not None else self.sel_a, - # ) - # self.t_avg = tf.get_variable( - # "t_avg", - # davg.shape, - # dtype=GLOBAL_TF_FLOAT_PRECISION, - # trainable=False, - # initializer=tf.constant_initializer(davg), - # ) - # self.t_std = tf.get_variable( - # "t_std", - # dstd.shape, - # dtype=GLOBAL_TF_FLOAT_PRECISION, - # trainable=False, - # initializer=tf.constant_initializer(dstd), - # ) coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) box = paddle.reshape(box_, [-1, 9]) atype = paddle.reshape(atype_, [-1, natoms[1]]) - # op_descriptor = ( - # build_op_descriptor() if nvnmd_cfg.enable else 
op_module.prod_env_mat_a - # ) ( self.descrpt, self.descrpt_deriv, @@ -650,13 +559,8 @@ def forward( sel_r=self.sel_r, ) # only used when tensorboard was set as true - # tf.summary.histogram("descrpt", self.descrpt) - # tf.summary.histogram("rij", self.rij) - # tf.summary.histogram("nlist", self.nlist) self.descrpt_reshape = paddle.reshape(self.descrpt, [-1, self.ndescrpt]) - # [1, 105984] --> [192, 552] self.descrpt_reshape.stop_gradient = False - # self._identity_tensors(suffix=suffix) self.dout, self.qmat = self._pass_filter( self.descrpt_reshape, atype, @@ -665,10 +569,8 @@ def forward( suffix=suffix, reuse=reuse, trainable=self.trainable, - ) # [1, all_atom, M1*M2], output_qmat: [1, all_atom, M1*3] + ) - # only used when tensorboard was set as true - # tf.summary.histogram("embedding_net_output", self.dout) return self.dout def get_rot_mat(self) -> paddle.Tensor: @@ -722,9 +624,6 @@ def prod_force_virial( n_a_sel=self.nnei_a, n_r_sel=self.nnei_r, ) - # tf.summary.histogram("force", force) - # tf.summary.histogram("virial", virial) - # tf.summary.histogram("atom_virial", atom_virial) return force, virial, atom_virial @@ -754,19 +653,16 @@ def _pass_filter( ------- Tuple[Tensor, Tensor]: output: [1, all_atom, M1*M2], output_qmat: [1, all_atom, M1*3] """ - # natoms = [192, 192, 64 , 128] if input_dict is not None: type_embedding = input_dict.get("type_embedding", None) else: type_embedding = None start_index = 0 - # print(inputs.shape) # [192, 552(nnei*4)],每个原子和它周围nnei个原子的R矩阵(展平后) inputs = paddle.reshape(inputs, [-1, int(natoms[0].item()), int(self.ndescrpt)]) output = [] output_qmat = [] if not self.type_one_side and type_embedding is None: for type_i in range(self.ntypes): - # 按不同原子类型进行处理 inputs_i = paddle.slice( inputs, [0, 1, 2], @@ -776,12 +672,10 @@ def _pass_filter( start_index + natoms[2 + type_i], inputs.shape[2], ], - ) # [1, 某种类型原子个数64/128, 552] - inputs_i = paddle.reshape( - inputs_i, [-1, self.ndescrpt] - ) # [某种类型原子个数64/128, 552] + ) + inputs_i = paddle.reshape(inputs_i, [-1, self.ndescrpt]) filter_name = "filter_type_" + str(type_i) + suffix - layer, qmat = self._filter( # 计算某个类型的原子的 result 和 qmat + layer, qmat = self._filter( inputs_i, type_i, name=filter_name, @@ -789,10 +683,10 @@ def _pass_filter( reuse=reuse, trainable=trainable, activation_fn=self.filter_activation_fn, - ) # [natom, M1*M2], qmat: [natom, M1, 3] + ) layer = paddle.reshape( layer, [inputs.shape[0], natoms[2 + type_i], self.get_dim_out()] - ) # [1, 某种类型原子个数64/128, M1*M2] + ) qmat = paddle.reshape( qmat, [ @@ -800,12 +694,12 @@ def _pass_filter( natoms[2 + type_i], self.get_dim_rot_mat_1() * 3, ], - ) # [1, 某种类型原子个数64/128, 100*3] + ) output.append(layer) output_qmat.append(qmat) start_index += natoms[2 + type_i] else: - ... 
+ raise NotImplementedError() # This branch will not be excecuted at current # inputs_i = inputs # inputs_i = paddle.reshape(inputs_i, [-1, self.ndescrpt]) @@ -847,25 +741,12 @@ def _pass_filter( # output_qmat.append(qmat) output = paddle.concat(output, axis=1) output_qmat = paddle.concat(output_qmat, axis=1) - # output: [1, 192, M1*M2] - # output_qmat: [1, 192, M1*3] return output, output_qmat def _compute_dstats_sys_smth( self, data_coord, data_box, data_atype, natoms_vec, mesh ): input_dict = {} - # dd_all = run_sess( - # self.sub_sess, - # self.stat_descrpt, - # feed_dict={ - # self.place_holders["coord"]: data_coord, - # self.place_holders["type"]: data_atype, - # self.place_holders["natoms_vec"]: natoms_vec, - # self.place_holders["box"]: data_box, - # self.place_holders["default_mesh"]: mesh, - # }, - # ) input_dict["coord"] = paddle.to_tensor(data_coord, dtype="float32") input_dict["box"] = paddle.to_tensor(data_box, dtype="float32") input_dict["type"] = paddle.to_tensor(data_atype, dtype="int32") @@ -875,13 +756,13 @@ def _compute_dstats_sys_smth( input_dict["default_mesh"] = paddle.to_tensor(mesh, dtype="int32") self.stat_descrpt, descrpt_deriv, rij, nlist = op_module.prod_env_mat_a( - input_dict["coord"], # fp32 - input_dict["type"], # int32 - input_dict["box"], # fp32 - input_dict["default_mesh"], # int32 + input_dict["coord"], + input_dict["type"], + input_dict["box"], + input_dict["default_mesh"], self.avg_zero, self.std_ones, - input_dict["natoms_vec"], # int32 + input_dict["natoms_vec"], rcut_a=self.rcut_a, rcut_r=self.rcut_r, rcut_r_smth=self.rcut_r_smth, @@ -987,16 +868,11 @@ def _filter_lower( type_input: int, # outer-loop start_index: int, incrs_index: int, - inputs: paddle.Tensor, # [1, 原子个数(64或128), 552(embedding_dim)] + inputs: paddle.Tensor, nframes: int, natoms: int, type_embedding=None, is_exclude=False, - # activation_fn=None, - # bavg=0.0, - # stddev=1.0, - # trainable=True, - # suffix="", ): """Input env matrix, returns R.G.""" outputs_size = [1] + self.filter_neuron @@ -1007,7 +883,7 @@ def _filter_lower( [0, 1], [0, start_index * 4], [inputs.shape[0], start_index * 4 + incrs_index * 4], - ) # 得到某个类型的原子i对邻域内类型为j的的原子关系,取出二者之间的描述矩阵R natom x nei_type_i x 4 + ) shape_i = inputs_i.shape natom = inputs_i.shape[0] @@ -1018,7 +894,7 @@ def _filter_lower( xyz_scatter = paddle.reshape( paddle.slice(inputs_reshape, [0, 1], [0, 0], [inputs_reshape.shape[0], 1]), [-1, 1], - ) # 得到某个类型的原子i对邻域内类型为j的的原子关系,取出二者之间的描述矩阵R矩阵的第一列s(rij) + ) if type_embedding is not None: xyz_scatter = self._concat_type_embedding( @@ -1029,25 +905,6 @@ def _filter_lower( "compression of type embedded descriptor is not supported at the moment" ) # natom x 4 x outputs_size - # if nvnmd_cfg.enable: - # return filter_lower_R42GR( - # type_i, - # type_input, - # inputs_i, - # is_exclude, - # activation_fn, - # bavg, - # stddev, - # trainable, - # suffix, - # self.seed, - # self.seed_shift, - # self.uniform_seed, - # self.filter_neuron, - # self.filter_precision, - # self.filter_resnet_dt, - # self.embedding_net_variables, - # ) if self.compress and (not is_exclude): if self.type_one_side: net = "filter_-1_net_" + str(type_i) @@ -1071,9 +928,7 @@ def _filter_lower( else: if not is_exclude: # excuted this branch - xyz_scatter_out = self.embedding_nets[type_input][type_i]( - xyz_scatter - ) # 对 s(rij) 进行embedding映射, (natom x nei_type_i) x 1==>(natom x nei_type_i) x 100,得到每个原子i对邻域内类型为j的的原子特征,所有该类型的原子的g_i的concat + xyz_scatter_out = self.embedding_nets[type_input][type_i](xyz_scatter) if (not self.uniform_seed) 
and (self.seed is not None): self.seed += self.seed_shift else: @@ -1093,23 +948,18 @@ def _filter_lower( # So we need to explicitly assign the shape to paddle.shape(inputs_i)[0] instead of -1 # natom x 4 x outputs_size - # [natom, nei_type_i, 4].T x [natom, nei_type_i, 100] - # 等价于 - # [natom, 4, nei_type_i] x [natom, nei_type_i, 100] - # ==> - # [natom, 4, 100] return paddle.matmul( paddle.reshape( inputs_i, [natom, shape_i[1] // 4, 4] ), # [natom, nei_type_i, 4] xyz_scatter_out, # [natom, nei_type_i, 100] transpose_x=True, - ) # 得到(R_i).T*g_i,即D_i表达式的右半部分 + ) # @cast_precision def _filter( self, - inputs: paddle.Tensor, # [1, 原子个数(64或128), 552(nnei*4)] + inputs: paddle.Tensor, type_input: int, natoms, type_embedding=None, @@ -1149,13 +999,9 @@ def _filter( ------- Tuple[Tensor, Tensor]: result: [64/128, M1*M2], qmat: [64/128, M1, 3] """ + # NOTE: code below is annotated as nframes computation is wrong # nframes = paddle.shape(paddle.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0] - # 上述 nframes的计算代码是错误的,reshape前后numel根本不相等,会导致程序报错,tf不会报错是因为tf计算图 - # 检测到这个变量后续不会被真正使用到,所以自动进行了优化。 - # nframes由于没有被使用到,所以这段代码没有被执行,所以tf实际运行是没有报错。 - # 复现报错很简单,只需要把这个nframes run出来,会导致这段代码被执行,然后报错。 - # 给 nframes 设置一个无用值 1 即可 nframes = 1 # natom x (nei x 4) shape = inputs.shape @@ -1192,31 +1038,23 @@ def _filter( rets = [] # execute this branch for type_i in range(self.ntypes): - # 计算type_input和type_i的原子之间的特征 ret = self._filter_lower( type_i, type_input, start_index, self.sel_a[type_i], # 46(O)/92(H) - inputs, # [1, 原子个数(64或128), 552(nnei*4)] + inputs, nframes, natoms, type_embedding=type_embedding, is_exclude=(type_input, type_i) in self.exclude_types, - # activation_fn=activation_fn, - # stddev=stddev, - # bavg=bavg, - # trainable=trainable, - # suffix="_" + str(type_i), - ) # ==> [natom_i, 4, 100] + ) if (type_input, type_i) not in self.exclude_types: # add zero is meaningless; skip rets.append(ret) start_index += self.sel_a[type_i] # faster to use accumulate_n than multiple add - xyz_scatter_1 = paddle.add_n( - rets - ) # 得到所有(R_i).T*g_i: [当前类型原子个数64/128, 4, embedding维度M1] + xyz_scatter_1 = paddle.add_n(rets) else: xyz_scatter_1 = self._filter_lower( type_i, @@ -1228,13 +1066,7 @@ def _filter( natoms, type_embedding=type_embedding, is_exclude=False, - # activation_fn=activation_fn, - # stddev=stddev, - # bavg=bavg, - # trainable=trainable, ) - # if nvnmd_cfg.enable: - # return filter_GR2D(xyz_scatter_1) # natom x nei x outputs_size # xyz_scatter = tf.concat(xyz_scatter_total, axis=1) # natom x nei x 4 @@ -1253,17 +1085,14 @@ def _filter( ), self.filter_precision, ) - xyz_scatter_1 = ( - xyz_scatter_1 / nnei - ) # (R_i).T*g_i: [当前类型原子个数64/128, 4, embedding维度M1] + xyz_scatter_1 = xyz_scatter_1 / nnei # natom x 4 x outputs_size_2 xyz_scatter_2 = paddle.slice( xyz_scatter_1, [0, 1, 2], [0, 0, 0], [xyz_scatter_1.shape[0], xyz_scatter_1.shape[1], outputs_size_2], - ) # [当前类型原子个数, R矩阵描述特征数4, 隐层特征数里的前16维特征(M2)], [64, 4, 16] - # (g_i<).T*(R_i): [当前类型原子个数64/128, 4, embedding前M2列] + ) # natom x 3 x outputs_size_2 # qmat = tf.slice(xyz_scatter_2, [0,1,0], [-1, 3, -1]) # natom x 3 x outputs_size_1 @@ -1276,16 +1105,10 @@ def _filter( # natom x outputs_size_1 x 3 qmat = paddle.transpose(qmat, perm=[0, 2, 1]) # [64/128, M1, 3] # natom x outputs_size x outputs_size_2 - result = paddle.matmul( - xyz_scatter_1, xyz_scatter_2, transpose_x=True - ) # [64/128,M1,4]x[64/128,4,M2]==>[64/128,M1,M2] + result = paddle.matmul(xyz_scatter_1, xyz_scatter_2, transpose_x=True) # natom x (outputs_size x outputs_size_2) - result = 
paddle.reshape( - result, [-1, outputs_size_2 * outputs_size[-1]] - ) # [64,M1*M2] + result = paddle.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) - # result: [64/128, M1*M2] - # qmat: [64/128, M1, 3] return result, qmat def init_variables( diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py index 566cff19b1..57efbc09ce 100755 --- a/deepmd/entrypoints/freeze.py +++ b/deepmd/entrypoints/freeze.py @@ -354,23 +354,7 @@ def freeze_graph( InputSpec(shape=[None], dtype="float64"), # box InputSpec(shape=[6], dtype="int32"), # mesh { - # "coord": InputSpec( - # shape=[2880], - # dtype="float64" - # ), - # "type": InputSpec( - # shape=[960], - # dtype="int32" - # ), - # "natoms_vec": InputSpec( - # shape=[4], - # dtype="int32" - # ), "box": InputSpec(shape=[None], dtype="float64"), - # "default_mesh": InputSpec( - # shape=[6], - # dtype="int32" - # ), }, "", False, @@ -380,17 +364,7 @@ def freeze_graph( print( f"[{name}, {param.shape}] generated name in static_model is: {param.name}" ) - # print(f"st_model.descrpt.buffer_rcut.name = {st_model.descrpt.buffer_rcut.name}") - # print( - # f"st_model.descrpt.buffer_ntypes.name = {st_model.descrpt.buffer_ntypes.name}" - # ) - # print( - # f"st_model.fitting.buffer_dfparam.name = {st_model.fitting.buffer_dfparam.name}" - # ) - # print( - # f"st_model.fitting.buffer_daparam.name = {st_model.fitting.buffer_daparam.name}" - # ) - # 跳过对program的裁剪,从而保留rcut、ntypes等不参与前向的参数,从而在C++端可以获取这些参数 + # skip pruning for program so as to keep buffers into files skip_prune_program = True print(f"==>> Set skip_prune_program = {skip_prune_program}") paddle.jit.save(st_model, output, skip_prune_program=skip_prune_program) @@ -475,12 +449,8 @@ def freeze_graph_multi( def freeze( *, - # checkpoint_folder: str, input_file: str, output: str, - # node_names: Optional[str] = None, - # nvnmd_weight: Optional[str] = None, - # united_model: bool = False, **kwargs, ): """Freeze the graph in supplied folder. @@ -494,78 +464,7 @@ def freeze( **kwargs other arguments """ - # We retrieve our checkpoint fullpath - # checkpoint = tf.train.get_checkpoint_state(checkpoint_folder) - # input_checkpoint = checkpoint.model_checkpoint_path - - # # expand the output file to full path - # output_graph = abspath(output) - - # # Before exporting our graph, we need to precise what is our output node - # # This is how TF decides what part of the Graph he has to keep - # # and what part it can dump - # # NOTE: this variable is plural, because you can have multiple output nodes - # # node_names = "energy_test,force_test,virial_test,t_rcut" - - # # We clear devices to allow TensorFlow to control - # # on which device it will load operations - # clear_devices = True - - # # We import the meta graph and retrieve a Saver - # try: - # # In case paralle training - # import horovod.tensorflow as _ # noqa: F401 - # except ImportError: - # pass - # saver = tf.train.import_meta_graph( - # f"{input_checkpoint}.meta", clear_devices=clear_devices - # ) - - # # We retrieve the protobuf graph definition - # graph = tf.get_default_graph() - # try: - # input_graph_def = graph.as_graph_def() - # except google.protobuf.message.DecodeError as e: - # raise GraphTooLargeError( - # "The graph size exceeds 2 GB, the hard limitation of protobuf." - # " Then a DecodeError was raised by protobuf. You should " - # "reduce the size of your model." 
- # ) from e - # nodes = [n.name for n in input_graph_def.node] - - # # We start a session and restore the graph weights - # with tf.Session() as sess: - # saver.restore(sess, input_checkpoint) - # model_type = run_sess(sess, "model_attr/model_type:0", feed_dict={}).decode( - # "utf-8" - # ) - # if "modifier_attr/type" in nodes: - # modifier_type = run_sess(sess, "modifier_attr/type:0", feed_dict={}).decode( - # "utf-8" - # ) - # else: - # modifier_type = None - # if nvnmd_weight is not None: - # save_weight(sess, nvnmd_weight) # nvnmd - # if model_type != "multi_task": freeze_graph( input_file, output, - # sess, - # input_graph_def, - # nodes, - # model_type, - # modifier_type, - # output_graph, - # node_names, ) - # else: - # freeze_graph_multi( - # sess, - # input_graph_def, - # nodes, - # modifier_type, - # output_graph, - # node_names, - # united_model=united_model, - # ) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index ecab32aefb..2ecc52ebe4 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -260,7 +260,7 @@ def test_ener( data.add("energy", 1, atomic=False, must=False, high_prec=True) data.add("force", 3, atomic=True, must=False, high_prec=False) data.add("virial", 9, atomic=False, must=False, high_prec=False) - if dp.has_efield: # False + if dp.has_efield: data.add("efield", 3, atomic=True, must=True, high_prec=False) if has_atom_ener: data.add("atom_ener", 1, atomic=True, must=True, high_prec=False) @@ -278,7 +278,6 @@ def test_ener( numb_test = min(nframes, numb_test) coord = test_data["coord"][:numb_test].reshape([numb_test, -1]) - box = test_data["box"][:numb_test] if dp.has_efield: efield = test_data["efield"][:numb_test].reshape([numb_test, -1]) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 472136a08e..1d295e091d 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -207,7 +207,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal dp_random.seed(seed) # setup data modifier - modifier = get_modifier(jdata["model"].get("modifier", None)) # None + modifier = get_modifier(jdata["model"].get("modifier", None)) # check the multi-task mode multi_task_mode = "fitting_net_dict" in jdata["model"] @@ -275,7 +275,6 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal origin_type_map = get_data( jdata["training"]["training_data"], rcut, None, modifier ).get_type_map() - print("model.build") model.build(train_data, stop_batch, origin_type_map=origin_type_map) if not is_compress: @@ -377,7 +376,7 @@ def get_nbor_stat(jdata, rcut, one_type: bool = False): if type_map and len(type_map) == 0: type_map = None multi_task_mode = "data_dict" in jdata["training"] - if not multi_task_mode: # here + if not multi_task_mode: train_data = get_data( jdata["training"]["training_data"], max_rcut, type_map, None ) @@ -419,15 +418,6 @@ def get_nbor_stat(jdata, rcut, one_type: bool = False): min_nbor_dist, max_nbor_size = neistat.get_stat(train_data) - # moved from traier.py as duplicated - # TODO: this is a simple fix but we should have a clear - # architecture to call neighbor stat - # tf.constant( - # min_nbor_dist, - # name="train_attr/min_nbor_dist", - # dtype=GLOBAL_ENER_FLOAT_PRECISION, - # ) - # tf.constant(max_nbor_size, name="train_attr/max_nbor_size", dtype=tf.int32) return min_nbor_dist, max_nbor_size @@ -473,9 +463,7 @@ def update_one_sel(jdata, descriptor): if descriptor["type"] == "loc_frame": return descriptor rcut = 
descriptor["rcut"] - tmp_sel = get_sel( - jdata, rcut, one_type=descriptor["type"] in ("se_atten",) - ) # [38 72],每个原子截断半径内,最多的邻域原子个数 + tmp_sel = get_sel(jdata, rcut, one_type=descriptor["type"] in ("se_atten",)) sel = descriptor["sel"] # [46, 92] if isinstance(sel, int): # convert to list and finnally convert back to int @@ -507,7 +495,7 @@ def update_sel(jdata): if descrpt_data["type"] == "hybrid": for ii in range(len(descrpt_data["list"])): descrpt_data["list"][ii] = update_one_sel(jdata, descrpt_data["list"][ii]) - else: # here + else: descrpt_data = update_one_sel(jdata, descrpt_data) jdata["model"]["descriptor"] = descrpt_data return jdata diff --git a/deepmd/env.py b/deepmd/env.py index 9eb7e1e6a8..044301c628 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -372,10 +372,7 @@ def get_module(module_name: str) -> "ModuleType": raise FileNotFoundError(f"module {module_name} does not exist") else: try: - # module = tf.load_op_library(str(module_file)) - import paddle_deepmd_lib - - module = paddle_deepmd_lib + import paddle_deepmd_lib as module except tf.errors.NotFoundError as e: # check CXX11_ABI_FLAG is compatiblity diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index eaac525c04..27bf6a2105 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -431,7 +431,6 @@ def _build_lower( bias_atom_e=0.0, type_suffix="", suffix="", - # reuse=None, type_i=None, ): # cut-out inputs @@ -462,51 +461,19 @@ def _build_lower( ext_aparam = paddle.cast(ext_aparam, self.fitting_precision) layer = paddle.concat([layer, ext_aparam], axis=1) - # if nvnmd_cfg.enable: - # one_layer = one_layer_nvnmd - # else: - # one_layer = one_layer_deepmd for ii in range(0, len(self.n_neuron)): - # if self.layer_name is not None and self.layer_name[ii] is not None: - # layer_suffix = "share_" + self.layer_name[ii] + type_suffix - # layer_reuse = tf.AUTO_REUSE - # else: - # layer_suffix = "layer_" + str(ii) + type_suffix + suffix - # layer_reuse = reuse if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]: layer += self.one_layers[type_i][ii](layer) else: layer = self.one_layers[type_i][ii](layer) - # print(f"use {ii} of {len(self.one_layers)}_{type_i}") - # if (not self.uniform_seed) and (self.seed is not None): - # self.seed += self.seed_shift - # if self.layer_name is not None and self.layer_name[-1] is not None: - # layer_suffix = "share_" + self.layer_name[-1] + type_suffix - # layer_reuse = tf.AUTO_REUSE - # else: - # layer_suffix = "final_layer" + type_suffix + suffix - # layer_reuse = reuse if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift - final_layer = self.final_layers[type_i]( - layer, - # 1, - # activation_fn=None, - # bavg=bias_atom_e, - # name=layer_suffix, - # reuse=layer_reuse, - # seed=self.seed, - # precision=self.fitting_precision, - # trainable=self.trainable[-1], - # uniform_seed=self.uniform_seed, - # initial_variables=self.fitting_net_variables, - # mixed_prec=self.mixed_prec, - # final_layer=True, - ) + + final_layer = self.final_layers[type_i](layer) if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift - return final_layer # [natoms, 1] + return final_layer def forward( self, @@ -577,9 +544,7 @@ def forward( self.bias_atom_e[type_i] = self.bias_atom_e[type_i] self.bias_atom_e = self.bias_atom_e[:ntypes_atom] - inputs = paddle.reshape( - inputs, [-1, natoms[0], self.dim_descrpt] - ) # [1, all_atoms, M1*M2] + inputs = paddle.reshape(inputs, [-1, natoms[0], self.dim_descrpt]) if len(self.atom_ener): # only for 
atom_ener nframes = input_dict.get("nframes") @@ -643,18 +608,6 @@ def forward( start_index = 0 outs_list = [] for type_i in range(ntypes_atom): - # final_layer = inputs - # for layer_j in range(type_i * ntypes_atom, (type_i + 1) * ntypes_atom): - # final_layer = self.one_layers[layer_j](final_layer) - # final_layer = self.final_layers[type_i](final_layer) - # print(final_layer.shape) - - # # concat the results - # if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: - # zero_layer = inputs_zero - # for layer_j in range(type_i * ntypes_atom, (type_i + 1) * ntypes_atom): - # zero_layer = self.one_layers[layer_j](zero_layer) - # zero_layer = self.final_layers[type_i](zero_layer) final_layer = self._build_lower( start_index, natoms[2 + type_i], @@ -664,7 +617,6 @@ def forward( bias_atom_e=0.0, type_suffix="_type_" + str(type_i), suffix=suffix, - # reuse=reuse, type_i=type_i, ) # concat the results @@ -678,13 +630,12 @@ def forward( bias_atom_e=0.0, type_suffix="_type_" + str(type_i), suffix=suffix, - # reuse=True, type_i=type_i, ) final_layer -= zero_layer final_layer = paddle.reshape( final_layer, [paddle.shape(inputs)[0], natoms[2 + type_i]] - ) # [1, natoms] + ) outs_list.append(final_layer) start_index += natoms[2 + type_i] # concat the results @@ -731,7 +682,7 @@ def forward( ), [paddle.shape(inputs)[0], paddle.sum(natoms[2 : 2 + ntypes_atom]).item()], ) - outs = outs + self.add_type # 类型编码(类似于transformer的位置编码,每种类型自己有一个特征,加到原特征上) + outs = outs + self.add_type outs *= atype_filter self.atom_ener_after = outs diff --git a/deepmd/fit/ener_tf.py b/deepmd/fit/ener_tf.py deleted file mode 100644 index aacdf5b67f..0000000000 --- a/deepmd/fit/ener_tf.py +++ /dev/null @@ -1,914 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import logging -from typing import ( - List, - Optional, -) - -import numpy as np - -from deepmd.common import ( - add_data_requirement, - cast_precision, - get_activation_func, - get_precision, -) -from deepmd.env import ( - GLOBAL_TF_FLOAT_PRECISION, - global_cvt_2_tf_float, - tf, -) -from deepmd.fit.fitting import ( - Fitting, -) -from deepmd.infer import ( - DeepPotential, -) -from deepmd.loss.ener import ( - EnerDipoleLoss, - EnerSpinLoss, - EnerStdLoss, -) -from deepmd.loss.loss import ( - Loss, -) -from deepmd.nvnmd.fit.ener import ( - one_layer_nvnmd, -) -from deepmd.nvnmd.utils.config import ( - nvnmd_cfg, -) -from deepmd.utils.errors import ( - GraphWithoutTensorError, -) -from deepmd.utils.graph import ( - get_fitting_net_variables_from_graph_def, - get_tensor_by_name_from_graph, -) -from deepmd.utils.network import one_layer as one_layer_deepmd -from deepmd.utils.network import ( - one_layer_rand_seed_shift, -) -from deepmd.utils.spin import ( - Spin, -) - -log = logging.getLogger(__name__) - - -@Fitting.register("ener") -class EnerFitting(Fitting): - r"""Fitting the energy of the system. The force and the virial can also be trained. - - The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`: - - .. math:: - E(\mathcal{D}) = \mathcal{L}^{(n)} \circ \mathcal{L}^{(n-1)} - \circ \cdots \circ \mathcal{L}^{(1)} \circ \mathcal{L}^{(0)} - - The first :math:`n` hidden layers :math:`\mathcal{L}^{(0)}, \cdots, \mathcal{L}^{(n-1)}` are given by - - .. math:: - \mathbf{y}=\mathcal{L}(\mathbf{x};\mathbf{w},\mathbf{b})= - \boldsymbol{\phi}(\mathbf{x}^T\mathbf{w}+\mathbf{b}) - - where :math:`\mathbf{x} \in \mathbb{R}^{N_1}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}^{N_2}` - is the output vector. 
:math:`\mathbf{w} \in \mathbb{R}^{N_1 \times N_2}` and - :math:`\mathbf{b} \in \mathbb{R}^{N_2}` are weights and biases, respectively, - both of which are trainable if `trainable[i]` is `True`. :math:`\boldsymbol{\phi}` - is the activation function. - - The output layer :math:`\mathcal{L}^{(n)}` is given by - - .. math:: - \mathbf{y}=\mathcal{L}^{(n)}(\mathbf{x};\mathbf{w},\mathbf{b})= - \mathbf{x}^T\mathbf{w}+\mathbf{b} - - where :math:`\mathbf{x} \in \mathbb{R}^{N_{n-1}}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}` - is the output scalar. :math:`\mathbf{w} \in \mathbb{R}^{N_{n-1}}` and - :math:`\mathbf{b} \in \mathbb{R}` are weights and bias, respectively, - both of which are trainable if `trainable[n]` is `True`. - - Parameters - ---------- - descrpt - The descrptor :math:`\mathcal{D}` - neuron - Number of neurons :math:`N` in each hidden layer of the fitting net - resnet_dt - Time-step `dt` in the resnet construction: - :math:`y = x + dt * \phi (Wx + b)` - numb_fparam - Number of frame parameter - numb_aparam - Number of atomic parameter - rcond - The condition number for the regression of atomic energy. - tot_ener_zero - Force the total energy to zero. Useful for the charge fitting. - trainable - If the weights of fitting net are trainable. - Suppose that we have :math:`N_l` hidden layers in the fitting net, - this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable. - seed - Random seed for initializing the network parameters. - atom_ener - Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. - activation_function - The activation function :math:`\boldsymbol{\phi}` in the embedding net. Supported options are |ACTIVATION_FN| - precision - The precision of the embedding net parameters. Supported options are |PRECISION| - uniform_seed - Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed - layer_name : list[Optional[str]], optional - The name of the each layer. If two layers, either in the same fitting or different fittings, - have the same name, they will share the same neural network parameters. - use_aparam_as_mask: bool, optional - If True, the atomic parameters will be used as a mask that determines the atom is real/virtual. - And the aparam will not be used as the atomic parameters for embedding. 
- """ - - def __init__( - self, - descrpt: tf.Tensor, - neuron: List[int] = [120, 120, 120], - resnet_dt: bool = True, - numb_fparam: int = 0, - numb_aparam: int = 0, - rcond: Optional[float] = None, - tot_ener_zero: bool = False, - trainable: Optional[List[bool]] = None, - seed: Optional[int] = None, - atom_ener: List[float] = [], - activation_function: str = "tanh", - precision: str = "default", - uniform_seed: bool = False, - layer_name: Optional[List[Optional[str]]] = None, - use_aparam_as_mask: bool = False, - spin: Optional[Spin] = None, - **kwargs, - ) -> None: - """Constructor.""" - # model param - self.ntypes = descrpt.get_ntypes() - self.dim_descrpt = descrpt.get_dim_out() - self.use_aparam_as_mask = use_aparam_as_mask - # args = ()\ - # .add('numb_fparam', int, default = 0)\ - # .add('numb_aparam', int, default = 0)\ - # .add('neuron', list, default = [120,120,120], alias = 'n_neuron')\ - # .add('resnet_dt', bool, default = True)\ - # .add('rcond', float, default = 1e-3) \ - # .add('tot_ener_zero', bool, default = False) \ - # .add('seed', int) \ - # .add('atom_ener', list, default = [])\ - # .add("activation_function", str, default = "tanh")\ - # .add("precision", str, default = "default")\ - # .add("trainable", [list, bool], default = True) - self.numb_fparam = numb_fparam - self.numb_aparam = numb_aparam - self.n_neuron = neuron - self.resnet_dt = resnet_dt - self.rcond = rcond - self.seed = seed - self.uniform_seed = uniform_seed - self.spin = spin - self.ntypes_spin = self.spin.get_ntypes_spin() if self.spin is not None else 0 - self.seed_shift = one_layer_rand_seed_shift() - self.tot_ener_zero = tot_ener_zero - self.fitting_activation_fn = get_activation_func(activation_function) - self.fitting_precision = get_precision(precision) - self.trainable = trainable - if self.trainable is None: - self.trainable = [True for ii in range(len(self.n_neuron) + 1)] - if isinstance(self.trainable, bool): - self.trainable = [self.trainable] * (len(self.n_neuron) + 1) - assert ( - len(self.trainable) == len(self.n_neuron) + 1 - ), "length of trainable should be that of n_neuron + 1" - self.atom_ener = [] - self.atom_ener_v = atom_ener - for at, ae in enumerate(atom_ener): - if ae is not None: - self.atom_ener.append( - tf.constant(ae, GLOBAL_TF_FLOAT_PRECISION, name="atom_%d_ener" % at) - ) - else: - self.atom_ener.append(None) - self.useBN = False - self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64) - # data requirement - if self.numb_fparam > 0: - add_data_requirement( - "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False - ) - self.fparam_avg = None - self.fparam_std = None - self.fparam_inv_std = None - if self.numb_aparam > 0: - add_data_requirement( - "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False - ) - self.aparam_avg = None - self.aparam_std = None - self.aparam_inv_std = None - - self.fitting_net_variables = None - self.mixed_prec = None - self.layer_name = layer_name - if self.layer_name is not None: - assert isinstance(self.layer_name, list), "layer_name should be a list" - assert ( - len(self.layer_name) == len(self.n_neuron) + 1 - ), "length of layer_name should be that of n_neuron + 1" - - def get_numb_fparam(self) -> int: - """Get the number of frame parameters.""" - return self.numb_fparam - - def get_numb_aparam(self) -> int: - """Get the number of atomic parameters.""" - return self.numb_fparam - - def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None: - """Compute the ouput statistics. 
- - Parameters - ---------- - all_stat - must have the following components: - all_stat['energy'] of shape n_sys x n_batch x n_frame - can be prepared by model.make_stat_input - mixed_type - Whether to perform the mixed_type mode. - If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), - in which frames in a system may have different natoms_vec(s), with the same nloc. - """ - self.bias_atom_e = self._compute_output_stats( - all_stat, rcond=self.rcond, mixed_type=mixed_type - ) - - def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False): - data = all_stat["energy"] - # data[sys_idx][batch_idx][frame_idx] - sys_ener = [] - for ss in range(len(data)): - sys_data = [] - for ii in range(len(data[ss])): - for jj in range(len(data[ss][ii])): - sys_data.append(data[ss][ii][jj]) - sys_data = np.concatenate(sys_data) - sys_ener.append(np.average(sys_data)) - sys_ener = np.array(sys_ener) - sys_tynatom = [] - if mixed_type: - data = all_stat["real_natoms_vec"] - nsys = len(data) - for ss in range(len(data)): - tmp_tynatom = [] - for ii in range(len(data[ss])): - for jj in range(len(data[ss][ii])): - tmp_tynatom.append(data[ss][ii][jj].astype(np.float64)) - tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0) - sys_tynatom.append(tmp_tynatom) - else: - data = all_stat["natoms_vec"] - nsys = len(data) - for ss in range(len(data)): - sys_tynatom.append(data[ss][0].astype(np.float64)) - sys_tynatom = np.array(sys_tynatom) - sys_tynatom = np.reshape(sys_tynatom, [nsys, -1]) - sys_tynatom = sys_tynatom[:, 2:] - if len(self.atom_ener) > 0: - # Atomic energies stats are incorrect if atomic energies are assigned. - # In this situation, we directly use these assigned energies instead of computing stats. - # This will make the loss decrease quickly - assigned_atom_ener = np.array( - [ee for ee in self.atom_ener_v if ee is not None] - ) - assigned_ener_idx = [ - ii for ii, ee in enumerate(self.atom_ener_v) if ee is not None - ] - # np.dot out size: nframe - sys_ener -= np.dot(sys_tynatom[:, assigned_ener_idx], assigned_atom_ener) - sys_tynatom[:, assigned_ener_idx] = 0.0 - energy_shift, resd, rank, s_value = np.linalg.lstsq( - sys_tynatom, sys_ener, rcond=rcond - ) - if len(self.atom_ener) > 0: - for ii in assigned_ener_idx: - energy_shift[ii] = self.atom_ener_v[ii] - return energy_shift - - def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None: - """Compute the input statistics. 
- - Parameters - ---------- - all_stat - if numb_fparam > 0 must have all_stat['fparam'] - if numb_aparam > 0 must have all_stat['aparam'] - can be prepared by model.make_stat_input - protection - Divided-by-zero protection - """ - # stat fparam - if self.numb_fparam > 0: - cat_data = np.concatenate(all_stat["fparam"], axis=0) - cat_data = np.reshape(cat_data, [-1, self.numb_fparam]) - self.fparam_avg = np.average(cat_data, axis=0) - self.fparam_std = np.std(cat_data, axis=0) - for ii in range(self.fparam_std.size): - if self.fparam_std[ii] < protection: - self.fparam_std[ii] = protection - self.fparam_inv_std = 1.0 / self.fparam_std - # stat aparam - if self.numb_aparam > 0: - sys_sumv = [] - sys_sumv2 = [] - sys_sumn = [] - for ss_ in all_stat["aparam"]: - ss = np.reshape(ss_, [-1, self.numb_aparam]) - sys_sumv.append(np.sum(ss, axis=0)) - sys_sumv2.append(np.sum(np.multiply(ss, ss), axis=0)) - sys_sumn.append(ss.shape[0]) - sumv = np.sum(sys_sumv, axis=0) - sumv2 = np.sum(sys_sumv2, axis=0) - sumn = np.sum(sys_sumn) - self.aparam_avg = (sumv) / sumn - self.aparam_std = self._compute_std(sumv2, sumv, sumn) - for ii in range(self.aparam_std.size): - if self.aparam_std[ii] < protection: - self.aparam_std[ii] = protection - self.aparam_inv_std = 1.0 / self.aparam_std - - def _compute_std(self, sumv2, sumv, sumn): - return np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn)) - - @cast_precision - def _build_lower( - self, - start_index, - natoms, - inputs, - fparam=None, - aparam=None, - bias_atom_e=0.0, - type_suffix="", - suffix="", - reuse=None, - ): - # cut-out inputs - inputs_i = tf.slice(inputs, [0, start_index, 0], [-1, natoms, -1]) - inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt]) - layer = inputs_i - if fparam is not None: - ext_fparam = tf.tile(fparam, [1, natoms]) - ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam]) - ext_fparam = tf.cast(ext_fparam, self.fitting_precision) - layer = tf.concat([layer, ext_fparam], axis=1) - if aparam is not None: - ext_aparam = tf.slice( - aparam, - [0, start_index * self.numb_aparam], - [-1, natoms * self.numb_aparam], - ) - ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam]) - ext_aparam = tf.cast(ext_aparam, self.fitting_precision) - layer = tf.concat([layer, ext_aparam], axis=1) - - if nvnmd_cfg.enable: - one_layer = one_layer_nvnmd - else: - one_layer = one_layer_deepmd - for ii in range(0, len(self.n_neuron)): - if self.layer_name is not None and self.layer_name[ii] is not None: - layer_suffix = "share_" + self.layer_name[ii] + type_suffix - layer_reuse = tf.AUTO_REUSE - else: - layer_suffix = "layer_" + str(ii) + type_suffix + suffix - layer_reuse = reuse - if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]: - layer += one_layer( - layer, - self.n_neuron[ii], - name=layer_suffix, - reuse=layer_reuse, - seed=self.seed, - use_timestep=self.resnet_dt, - activation_fn=self.fitting_activation_fn, - precision=self.fitting_precision, - trainable=self.trainable[ii], - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - ) - else: - layer = one_layer( - layer, - self.n_neuron[ii], - name=layer_suffix, - reuse=layer_reuse, - seed=self.seed, - activation_fn=self.fitting_activation_fn, - precision=self.fitting_precision, - trainable=self.trainable[ii], - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - ) - if (not self.uniform_seed) and (self.seed is not None): - self.seed += self.seed_shift 
- if self.layer_name is not None and self.layer_name[-1] is not None: - layer_suffix = "share_" + self.layer_name[-1] + type_suffix - layer_reuse = tf.AUTO_REUSE - else: - layer_suffix = "final_layer" + type_suffix + suffix - layer_reuse = reuse - final_layer = one_layer( - layer, - 1, - activation_fn=None, - bavg=bias_atom_e, - name=layer_suffix, - reuse=layer_reuse, - seed=self.seed, - precision=self.fitting_precision, - trainable=self.trainable[-1], - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - final_layer=True, - ) - if (not self.uniform_seed) and (self.seed is not None): - self.seed += self.seed_shift - - return final_layer - - def build( - self, - inputs: tf.Tensor, - natoms: tf.Tensor, - input_dict: Optional[dict] = None, - reuse: Optional[bool] = None, - suffix: str = "", - ) -> tf.Tensor: - """Build the computational graph for fitting net. - - Parameters - ---------- - inputs - The input descriptor - input_dict - Additional dict for inputs. - if numb_fparam > 0, should have input_dict['fparam'] - if numb_aparam > 0, should have input_dict['aparam'] - natoms - The number of atoms. This tensor has the length of Ntypes + 2 - natoms[0]: number of local atoms - natoms[1]: total number of atoms held by this processor - natoms[i]: 2 <= i < Ntypes+2, number of type i atoms - reuse - The weights in the networks should be reused when get the variable. - suffix - Name suffix to identify this descriptor - - Returns - ------- - ener - The system energy - """ - if input_dict is None: - input_dict = {} - bias_atom_e = self.bias_atom_e - type_embedding = input_dict.get("type_embedding", None) - atype = input_dict.get("atype", None) - if self.numb_fparam > 0: - if self.fparam_avg is None: - self.fparam_avg = 0.0 - if self.fparam_inv_std is None: - self.fparam_inv_std = 1.0 - if self.numb_aparam > 0: - if self.aparam_avg is None: - self.aparam_avg = 0.0 - if self.aparam_inv_std is None: - self.aparam_inv_std = 1.0 - - ntypes_atom = self.ntypes - self.ntypes_spin - if self.spin is not None: - for type_i in range(ntypes_atom): - if self.bias_atom_e.shape[0] != self.ntypes: - self.bias_atom_e = np.pad( - self.bias_atom_e, - (0, self.ntypes_spin), - "constant", - constant_values=(0, 0), - ) - bias_atom_e = self.bias_atom_e - if self.spin.use_spin[type_i]: - self.bias_atom_e[type_i] = ( - self.bias_atom_e[type_i] - + self.bias_atom_e[type_i + ntypes_atom] - ) - else: - self.bias_atom_e[type_i] = self.bias_atom_e[type_i] - self.bias_atom_e = self.bias_atom_e[:ntypes_atom] - - with tf.variable_scope("fitting_attr" + suffix, reuse=reuse): - # t_dfparam = tf.constant(self.numb_fparam, name="dfparam", dtype=tf.int32) - # t_daparam = tf.constant(self.numb_aparam, name="daparam", dtype=tf.int32) - self.t_bias_atom_e = tf.get_variable( - "t_bias_atom_e", - self.bias_atom_e.shape, - dtype=GLOBAL_TF_FLOAT_PRECISION, - trainable=False, - initializer=tf.constant_initializer(self.bias_atom_e), - ) - if self.numb_fparam > 0: - t_fparam_avg = tf.get_variable( - "t_fparam_avg", - self.numb_fparam, - dtype=GLOBAL_TF_FLOAT_PRECISION, - trainable=False, - initializer=tf.constant_initializer(self.fparam_avg), - ) - t_fparam_istd = tf.get_variable( - "t_fparam_istd", - self.numb_fparam, - dtype=GLOBAL_TF_FLOAT_PRECISION, - trainable=False, - initializer=tf.constant_initializer(self.fparam_inv_std), - ) - if self.numb_aparam > 0: - t_aparam_avg = tf.get_variable( - "t_aparam_avg", - self.numb_aparam, - dtype=GLOBAL_TF_FLOAT_PRECISION, - trainable=False, - 
initializer=tf.constant_initializer(self.aparam_avg), - ) - t_aparam_istd = tf.get_variable( - "t_aparam_istd", - self.numb_aparam, - dtype=GLOBAL_TF_FLOAT_PRECISION, - trainable=False, - initializer=tf.constant_initializer(self.aparam_inv_std), - ) - - inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt]) - if len(self.atom_ener): - # only for atom_ener - nframes = input_dict.get("nframes") - if nframes is not None: - # like inputs, but we don't want to add a dependency on inputs - inputs_zero = tf.zeros( - (nframes, natoms[0], self.dim_descrpt), - dtype=GLOBAL_TF_FLOAT_PRECISION, - ) - else: - inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION) - - if bias_atom_e is not None: - assert len(bias_atom_e) == self.ntypes - - fparam = None - if self.numb_fparam > 0: - fparam = input_dict["fparam"] - fparam = tf.reshape(fparam, [-1, self.numb_fparam]) - fparam = (fparam - t_fparam_avg) * t_fparam_istd - - aparam = None - if not self.use_aparam_as_mask: - if self.numb_aparam > 0: - aparam = input_dict["aparam"] - aparam = tf.reshape(aparam, [-1, self.numb_aparam]) - aparam = (aparam - t_aparam_avg) * t_aparam_istd - aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) - - atype_nall = tf.reshape(atype, [-1, natoms[1]]) - self.atype_nloc = tf.slice( - atype_nall, [0, 0], [-1, natoms[0]] - ) ## lammps will make error - atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) - self.atype_nloc = tf.reshape(self.atype_nloc, [-1]) - # prevent embedding_lookup error, - # but the filter will be applied anyway - self.atype_nloc = tf.clip_by_value(self.atype_nloc, 0, self.ntypes - 1) - - ## if spin is used - if self.spin is not None: - self.atype_nloc = tf.slice( - atype_nall, [0, 0], [-1, tf.reduce_sum(natoms[2 : 2 + ntypes_atom])] - ) - atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) - self.atype_nloc = tf.reshape(self.atype_nloc, [-1]) - if ( - nvnmd_cfg.enable - and nvnmd_cfg.quantize_descriptor - and nvnmd_cfg.restore_descriptor - and (nvnmd_cfg.version == 1) - ): - type_embedding = nvnmd_cfg.map["t_ebd"] - if type_embedding is not None: - atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc) - else: - atype_embed = None - - self.atype_embed = atype_embed - - if atype_embed is None: - start_index = 0 - outs_list = [] - for type_i in range(ntypes_atom): - final_layer = self._build_lower( - start_index, - natoms[2 + type_i], - inputs, - fparam, - aparam, - bias_atom_e=0.0, - type_suffix="_type_" + str(type_i), - suffix=suffix, - reuse=reuse, - ) - # concat the results - if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: - zero_layer = self._build_lower( - start_index, - natoms[2 + type_i], - inputs_zero, - fparam, - aparam, - bias_atom_e=0.0, - type_suffix="_type_" + str(type_i), - suffix=suffix, - reuse=True, - ) - final_layer -= zero_layer - final_layer = tf.reshape( - final_layer, [tf.shape(inputs)[0], natoms[2 + type_i]] - ) - outs_list.append(final_layer) - start_index += natoms[2 + type_i] - # concat the results - # concat once may be faster than multiple concat - outs = tf.concat(outs_list, axis=1) - # with type embedding - else: - atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION) - type_shape = atype_embed.get_shape().as_list() - inputs = tf.concat( - [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1 - ) - original_dim_descrpt = self.dim_descrpt - self.dim_descrpt = self.dim_descrpt + type_shape[1] - inputs = tf.reshape(inputs, [-1, natoms[0], 
self.dim_descrpt]) - final_layer = self._build_lower( - 0, - natoms[0], - inputs, - fparam, - aparam, - bias_atom_e=0.0, - suffix=suffix, - reuse=reuse, - ) - if len(self.atom_ener): - # remove contribution in vacuum - inputs_zero = tf.concat( - [tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed], - axis=1, - ) - inputs_zero = tf.reshape(inputs_zero, [-1, natoms[0], self.dim_descrpt]) - zero_layer = self._build_lower( - 0, - natoms[0], - inputs_zero, - fparam, - aparam, - bias_atom_e=0.0, - suffix=suffix, - reuse=True, - ) - # atomic energy will be stored in `self.t_bias_atom_e` which is not trainable - final_layer -= zero_layer - outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]]) - # add bias - self.atom_ener_before = outs * atype_filter - # atomic bias energy from data statistics - self.atom_bias_ener = tf.reshape( - tf.nn.embedding_lookup(self.t_bias_atom_e, self.atype_nloc), - [tf.shape(inputs)[0], tf.reduce_sum(natoms[2 : 2 + ntypes_atom])], - ) - outs = outs + self.atom_bias_ener - outs *= atype_filter - self.atom_bias_ener *= atype_filter - self.atom_ener_after = outs - - if self.tot_ener_zero: - force_tot_ener = 0.0 - outs = tf.reshape(outs, [-1, tf.reduce_sum(natoms[2 : 2 + ntypes_atom])]) - outs_mean = tf.reshape(tf.reduce_mean(outs, axis=1), [-1, 1]) - outs_mean = outs_mean - tf.ones_like( - outs_mean, dtype=GLOBAL_TF_FLOAT_PRECISION - ) * ( - force_tot_ener - / global_cvt_2_tf_float(tf.reduce_sum(natoms[2 : 2 + ntypes_atom])) - ) - outs = outs - outs_mean - outs = tf.reshape(outs, [-1]) - - tf.summary.histogram("fitting_net_output", outs) - return tf.reshape(outs, [-1]) - - def init_variables( - self, - graph: tf.Graph, - graph_def: tf.GraphDef, - suffix: str = "", - ) -> None: - """Init the fitting net variables with the given dict. - - Parameters - ---------- - graph : tf.Graph - The input frozen model graph - graph_def : tf.GraphDef - The input frozen model graph_def - suffix : str - suffix to name scope - """ - self.fitting_net_variables = get_fitting_net_variables_from_graph_def( - graph_def, suffix=suffix - ) - if self.layer_name is not None: - # shared variables have no suffix - shared_variables = get_fitting_net_variables_from_graph_def( - graph_def, suffix="" - ) - self.fitting_net_variables.update(shared_variables) - if self.numb_fparam > 0: - self.fparam_avg = get_tensor_by_name_from_graph( - graph, "fitting_attr%s/t_fparam_avg" % suffix - ) - self.fparam_inv_std = get_tensor_by_name_from_graph( - graph, "fitting_attr%s/t_fparam_istd" % suffix - ) - if self.numb_aparam > 0: - self.aparam_avg = get_tensor_by_name_from_graph( - graph, "fitting_attr%s/t_aparam_avg" % suffix - ) - self.aparam_inv_std = get_tensor_by_name_from_graph( - graph, "fitting_attr%s/t_aparam_istd" % suffix - ) - try: - self.bias_atom_e = get_tensor_by_name_from_graph( - graph, "fitting_attr%s/t_bias_atom_e" % suffix - ) - except GraphWithoutTensorError: - # for compatibility, old models has no t_bias_atom_e - pass - - def change_energy_bias( - self, - data, - frozen_model, - origin_type_map, - full_type_map, - bias_shift="delta", - ntest=10, - ) -> None: - """Change the energy bias according to the input data and the pretrained model. - - Parameters - ---------- - data : DeepmdDataSystem - The training data. - frozen_model : str - The path file of frozen model. - origin_type_map : list - The original type_map in dataset, they are targets to change the energy bias. 
- full_type_map : str - The full type_map in pretrained model - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, - and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. - ntest : int - The number of test samples in a system to change the energy bias. - """ - type_numbs = [] - energy_ground_truth = [] - energy_predict = [] - sorter = np.argsort(full_type_map) - idx_type_map = sorter[ - np.searchsorted(full_type_map, origin_type_map, sorter=sorter) - ] - mixed_type = data.mixed_type - numb_type = len(full_type_map) - dp = None - if bias_shift == "delta": - # init model - dp = DeepPotential(frozen_model) - for sys in data.data_systems: - test_data = sys.get_test() - nframes = test_data["box"].shape[0] - numb_test = min(nframes, ntest) - if mixed_type: - atype = test_data["type"][:numb_test].reshape([numb_test, -1]) - else: - atype = test_data["type"][0] - assert np.array( - [i in idx_type_map for i in list(set(atype.reshape(-1)))] - ).all(), "Some types are not in 'type_map'!" - energy_ground_truth.append( - test_data["energy"][:numb_test].reshape([numb_test, 1]) - ) - if mixed_type: - type_numbs.append( - np.array( - [(atype == i).sum(axis=-1) for i in idx_type_map], - dtype=np.int32, - ).T - ) - else: - type_numbs.append( - np.tile( - np.bincount(atype, minlength=numb_type)[idx_type_map], - (numb_test, 1), - ) - ) - if bias_shift == "delta": - coord = test_data["coord"][:numb_test].reshape([numb_test, -1]) - if sys.pbc: - box = test_data["box"][:numb_test] - else: - box = None - ret = dp.eval(coord, box, atype, mixed_type=mixed_type) - energy_predict.append(ret[0].reshape([numb_test, 1])) - type_numbs = np.concatenate(type_numbs) - energy_ground_truth = np.concatenate(energy_ground_truth) - old_bias = self.bias_atom_e[idx_type_map] - if bias_shift == "delta": - energy_predict = np.concatenate(energy_predict) - bias_diff = energy_ground_truth - energy_predict - delta_bias = np.linalg.lstsq(type_numbs, bias_diff, rcond=None)[0] - unbias_e = energy_predict + type_numbs @ delta_bias - atom_numbs = type_numbs.sum(-1) - rmse_ae = ( - np.sqrt(np.square(unbias_e - energy_ground_truth)) / atom_numbs - ).mean() - self.bias_atom_e[idx_type_map] += delta_bias.reshape(-1) - log.info( - f"RMSE of atomic energy after linear regression is: {rmse_ae} eV/atom." - ) - elif bias_shift == "statistic": - statistic_bias = np.linalg.lstsq( - type_numbs, energy_ground_truth, rcond=None - )[0] - self.bias_atom_e[idx_type_map] = statistic_bias.reshape(-1) - else: - raise RuntimeError("Unknown bias_shift mode: " + bias_shift) - log.info( - "Change energy bias of {} from {} to {}.".format( - str(origin_type_map), str(old_bias), str(self.bias_atom_e[idx_type_map]) - ) - ) - - def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None: - """Reveive the mixed precision setting. - - Parameters - ---------- - mixed_prec - The mixed precision setting used in the embedding net - """ - self.mixed_prec = mixed_prec - self.fitting_precision = get_precision(mixed_prec["output_prec"]) - - def get_loss(self, loss: dict, lr) -> Loss: - """Get the loss function. - - Parameters - ---------- - loss : dict - The loss function parameters. - lr : LearningRateExp - The learning rate. - - Returns - ------- - Loss - The loss function. 
- """ - _loss_type = loss.pop("type", "ener") - loss["starter_learning_rate"] = lr.start_lr() - if _loss_type == "ener": - return EnerStdLoss(**loss) - elif _loss_type == "ener_dipole": - return EnerDipoleLoss(**loss) - elif _loss_type == "ener_spin": - return EnerSpinLoss(**loss, use_spin=self.spin.use_spin) - else: - raise RuntimeError("unknown loss type") diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 5b9b024a7d..27c6af754f 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -8,9 +8,6 @@ Union, ) -# from deepmd.descriptor.descriptor import ( -# Descriptor, -# ) import numpy as np from deepmd.common import ( @@ -130,18 +127,18 @@ def __init__( for k, v in load_state_dict.items(): if k in self.model.state_dict(): if load_state_dict[k].dtype != self.model.state_dict()[k].dtype: - print( - f"convert {k}'s dtype from {load_state_dict[k].dtype} to {self.model.state_dict()[k].dtype}" - ) + # print( + # f"convert {k}'s dtype from {load_state_dict[k].dtype} to {self.model.state_dict()[k].dtype}" + # ) load_state_dict[k] = load_state_dict[k].astype( self.model.state_dict()[k].dtype ) if list(load_state_dict[k].shape) != list( self.model.state_dict()[k].shape ): - print( - f"convert {k}'s shape from {load_state_dict[k].shape} to {self.model.state_dict()[k].shape}" - ) + # print( + # f"convert {k}'s shape from {load_state_dict[k].shape} to {self.model.state_dict()[k].shape}" + # ) load_state_dict[k] = load_state_dict[k].reshape( self.model.state_dict()[k].shape ) @@ -295,7 +292,6 @@ def _load_graph( # # print(param.shape) # if param.shape == (2,): # print(constant_op.outputs[0], param) - return graph @staticmethod diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index efeaf55bc5..377c776320 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -86,24 +86,6 @@ def __init__( # fitting attrs "dfparam": "fitting.t_dfparam", "daparam": "fitting.t_daparam", - # # fitting attrs - # "t_dfparam": "fitting_attr/dfparam:0", - # "t_daparam": "fitting_attr/daparam:0", - # # model attrs - # "t_tmap": "model_attr/tmap:0", - # # inputs - # "t_coord": "t_coord:0", - # "t_type": "t_type:0", - # "t_natoms": "t_natoms:0", - # "t_box": "t_box:0", - # "t_mesh": "t_mesh:0", - # # add output tensors - # "t_energy": "o_energy:0", - # "t_force": "o_force:0", - # "t_virial": "o_virial:0", - # "t_ae": "o_atom_energy:0", - # "t_av": "o_atom_virial:0", - # "t_descriptor": "o_descriptor:0", }, ) DeepEval.__init__( @@ -115,39 +97,11 @@ def __init__( ) # # load optional tensors - # operations = [op.name for op in self.graph.get_operations()] - # # check if the graph has these operations: - # # if yes add them - # if "t_efield" in operations: - # # self._get_tensor("t_efield:0", "t_efield") - # if self._get_value("t_efield") is not None: - # self._get_value("t_efield", "t_efield") - # self.has_efield = True - # else: - # log.debug("Could not get tensor 't_efield'") - # self.t_efield = None self.has_efield = False - # if self._get_value("load/t_fparam") is not None: - # self.tensors.update({"t_fparam": "t_fparam"}) - # self.has_fparam = True - # else: - # log.debug("Could not get tensor 't_fparam'") - # self.t_fparam = None self.has_fparam = False - # if self._get_value("load/t_aparam") is not None: - # self.tensors.update({"t_aparam": "t_aparam"}) - # self.has_aparam = True - # else: - # log.debug("Could not get tensor 't_aparam'") - # self.t_aparam = None self.has_aparam = False - - # if self._get_value("load/spin_attr/ntypes_spin") is not None: - # 
self.tensors.update({"t_ntypes_spin": "spin_attr/ntypes_spin"}) - # self.has_spin = True - # else: self.ntypes_spin = 0 self.has_spin = False @@ -159,62 +113,20 @@ def __init__( if attr_name != "t_descriptor": raise - # self._run_default_sess() - # self.tmap = self.tmap.decode("UTF-8").split() self.ntypes = 2 self.rcut = 6.0 self.dfparam = 0 self.daparam = 0 - # self.t_tmap = self.model.t_tmap.split() self.t_tmap = ["O", "H"] # setup modifier try: - # t_modifier_type = self._get_tensor("modifier_attr/type:0") - # self.modifier_type = run_sess(self.sess, t_modifier_type).decode("UTF-8") self.modifier_type = self._get_value("modifier_attr.type") except (ValueError, KeyError): self.modifier_type = None self.modifier_type = None self.descriptor_type = "se_e2_a" - # try: - # t_jdata = self._get_tensor("train_attr/training_script") - # jdata = run_sess(self.sess, t_jdata).decode("UTF-8") - # import json - - # jdata = json.loads(jdata) - # self.descriptor_type = jdata["model"]["descriptor"]["type"] - # except (ValueError, KeyError): - # self.descriptor_type = None - - # if self.modifier_type == "dipole_charge": - # t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0") - # t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0") - # t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0") - # t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0") - # t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0") - # [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess( - # self.sess, - # [ - # t_mdl_name, - # t_mdl_charge_map, - # t_sys_charge_map, - # t_ewald_h, - # t_ewald_beta, - # ], - # ) - # mdl_name = mdl_name.decode("UTF-8") - # mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()] - # sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()] - # self.dm = DipoleChargeModifier( - # mdl_name, - # mdl_charge_map, - # sys_charge_map, - # ewald_h=ewald_h, - # ewald_beta=ewald_beta, - # ) - def _run_default_sess(self): if self.has_spin is True: [ @@ -395,7 +307,7 @@ def eval( mixed_type=mixed_type, ) - if self.modifier_type is not None: # 这里不会运行 + if self.modifier_type is not None: if atomic: raise RuntimeError("modifier does not support atomic modification") me, mf, mv = self.dm.eval(coords, cells, atom_types) @@ -487,32 +399,6 @@ def _prepare_feed_dict( assert natoms_vec[0] == natoms # evaluate - # feed_dict_test = {} - # feed_dict_test[self.t_natoms] = natoms_vec - # if mixed_type: - # feed_dict_test[self.t_type] = atom_types.reshape([-1]) - # else: - # feed_dict_test[self.t_type] = np.tile(atom_types, [nframes, 1]).reshape( - # [-1] - # ) - # feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) - - # if len(self.t_box.shape) == 1: - # feed_dict_test[self.t_box] = np.reshape(cells, [-1]) - # elif len(self.t_box.shape) == 2: - # feed_dict_test[self.t_box] = cells - # else: - # raise RuntimeError - # if self.has_efield: - # feed_dict_test[self.t_efield] = np.reshape(efield, [-1]) - # if pbc: - # feed_dict_test[self.t_mesh] = make_default_mesh(cells) - # else: - # feed_dict_test[self.t_mesh] = np.array([], dtype=np.int32) - # if self.has_fparam: - # feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) - # if self.has_aparam: - # feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) return None, None, natoms_vec def _eval_inner( @@ -533,48 +419,6 @@ def _eval_inner( coords, cells, atom_types, fparam, aparam, efield, mixed_type=mixed_type ) - # t_out = [self.t_energy, self.t_force, self.t_virial] 
- # if atomic: - # t_out += [self.t_ae, self.t_av] - - # v_out = run_sess(self.sess, t_out, feed_dict=feed_dict_test) - # energy = v_out[0] - # force = v_out[1] - # virial = v_out[2] - # if atomic: - # ae = v_out[3] - # av = v_out[4] - - # if self.has_spin: - # ntypes_real = self.ntypes - self.ntypes_spin - # natoms_real = sum( - # [ - # np.count_nonzero(np.array(atom_types) == ii) - # for ii in range(ntypes_real) - # ] - # ) - # else: - # natoms_real = natoms - - # # reverse map of the outputs - # force = self.reverse_map(np.reshape(force, [nframes, -1, 3]), imap) - # if atomic: - # ae = self.reverse_map(np.reshape(ae, [nframes, -1, 1]), imap[:natoms_real]) - # av = self.reverse_map(np.reshape(av, [nframes, -1, 9]), imap) - - # energy = np.reshape(energy, [nframes, 1]) - # force = np.reshape(force, [nframes, natoms, 3]) - # virial = np.reshape(virial, [nframes, 9]) - # if atomic: - # ae = np.reshape(ae, [nframes, natoms_real, 1]) - # av = np.reshape(av, [nframes, natoms, 9]) - # return energy, force, virial, ae, av - # else: - # atom_types = np.array(atom_types, dtype=int).reshape([-1]) - # natoms = atom_types.size - # coords = np.reshape(np.array(coords), [-1, natoms * 3]) - # nframes = coords.shape[0] - eval_inputs = {} eval_inputs["coord"] = paddle.to_tensor( np.reshape(coords, [-1]), dtype="float64" ) @@ -586,10 +430,6 @@ def _eval_inner( natoms_vec, dtype="int32", place="cpu" ) eval_inputs["box"] = paddle.to_tensor(np.reshape(cells, [-1]), dtype="float64") - # print(eval_inputs['coord'].shape) # [2880] - # print(eval_inputs['type'].shape) # [960] - # print(eval_inputs['natoms_vec'].shape) # [4] - # print(eval_inputs['box'].shape) # [45] if self.has_fparam: eval_inputs["fparam"] = paddle.to_tensor( @@ -599,21 +439,18 @@ def _eval_inner( eval_inputs["aparam"] = paddle.to_tensor( np.reshape(aparam, [-1], dtype="float64") ) - # if se.pbc: eval_inputs["default_mesh"] = paddle.to_tensor( make_default_mesh(cells), dtype="int32" ) - # else: - # eval_inputs['default_mesh'] = paddle.to_tensor(np.array([], dtype = np.int32)) if hasattr(self, "st_model"): # NOTE: run inference with the static graph model eval_outputs = self.st_model( - eval_inputs["coord"], # [2880] paddle.float64 - eval_inputs["type"], # [960] paddle.int32 - eval_inputs["natoms_vec"], # [4] paddle.int32 - eval_inputs["box"], # [45] paddle.float64 - eval_inputs["default_mesh"], # [6] paddle.int32 + eval_inputs["coord"], + eval_inputs["type"], + eval_inputs["natoms_vec"], + eval_inputs["box"], + eval_inputs["default_mesh"], ) eval_outputs = { "atom_ener": eval_outputs[0], @@ -627,11 +464,11 @@ def _eval_inner( else: # NOTE: run inference with the dynamic graph model eval_outputs = self.model( - eval_inputs["coord"], # [2880] paddle.float64 - eval_inputs["type"], # [960] paddle.int32 - eval_inputs["natoms_vec"], # [4] paddle.int32 - eval_inputs["box"], # [45] paddle.float64 - eval_inputs["default_mesh"], # [6] paddle.int32 + eval_inputs["coord"], + eval_inputs["type"], + eval_inputs["natoms_vec"], + eval_inputs["box"], + eval_inputs["default_mesh"], eval_inputs, suffix="", reuse=False, diff --git a/deepmd/loss/ener.py index b0d453919d..d11177ee3a 100644 --- a/deepmd/loss/ener.py +++ b/deepmd/loss/ener.py @@ -195,11 +195,6 @@ def compute_loss(self, learning_rate, natoms, model_dict, label_dict, suffix): l2_loss = 0 more_loss = {} - # print(self.has_e) - # print(self.has_f) - # print(self.has_v) - # print(self.has_ae) - # print(self.has_pf) if self.has_e: # true l2_loss += atom_norm_ener * (pref_e * l2_ener_loss) more_loss["l2_ener_loss"] = l2_ener_loss @@ -216,24 +211,6 @@
def compute_loss(self, learning_rate, natoms, model_dict, label_dict, suffix): l2_loss += pref_pf * l2_pref_force_loss more_loss["l2_pref_force_loss"] = l2_pref_force_loss - # only used when tensorboard was set as true - # self.l2_loss_summary = paddle.summary.scalar("l2_loss_" + suffix, paddle.sqrt(l2_loss)) - # if self.has_e: - # self.l2_loss_ener_summary = paddle.summary.scalar( - # "l2_ener_loss_" + suffix, - # global_cvt_2_tf_float(paddle.sqrt(l2_ener_loss)) - # / global_cvt_2_tf_float(natoms[0]), - # ) - # if self.has_f: - # self.l2_loss_force_summary = paddle.summary.scalar( - # "l2_force_loss_" + suffix, paddle.sqrt(l2_force_loss) - # ) - # if self.has_v: - # self.l2_loss_virial_summary = paddle.summary.scalar( - # "l2_virial_loss_" + suffix, - # paddle.sqrt(l2_virial_loss) / global_cvt_2_tf_float(natoms[0]), - # ) - self.l2_l = l2_loss self.l2_more = more_loss return l2_loss, more_loss @@ -275,7 +252,6 @@ def eval(self, model, batch_data, natoms): reuse=False, ) l2_l, l2_more = self.compute_loss( - # 0.0, natoms, model_dict, batch_data 0.0, model_inputs["natoms_vec"], model_pred, diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index 10f26f597e..21c5ec6ee0 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -72,9 +72,8 @@ def __init__( sw_rmax: Optional[float] = None, spin: Optional[Spin] = None, ) -> None: - super().__init__() - # super(EnerModel, self).__init__(name_scope="EnerModel") """Constructor.""" + super().__init__() # descriptor self.descrpt = descrpt self.rcut = self.descrpt.get_rcut() @@ -102,9 +101,7 @@ def __init__( else: self.srtab = None - # self.type_map = " ".join(self.type_map) self.t_tmap = " ".join(self.type_map) - print(self.t_tmap) self.t_mt = self.model_type self.t_ver = str(MODEL_VERSION) # NOTE: workaround for string type is not supported in Paddle @@ -136,7 +133,6 @@ def data_stat(self, data): m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type ) self._compute_output_stat(all_stat, mixed_type=data.mixed_type) - # self.bias_atom_e = data.compute_energy_shift(self.rcond) def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False): if mixed_type: @@ -180,51 +176,11 @@ def forward( suffix="", reuse=None, ): - # print(__file__, coord_.shape) - # print(__file__, atype_.shape) - # print(__file__, natoms.shape) - # print(__file__, box.shape) - # print(__file__, mesh.shape) - # for k, v in input_dict.items(): - # print(f"{__file__} {k} {v.shape}") - if input_dict is None: input_dict = {} - # if self.srtab is not None: - # tab_info, tab_data = self.srtab.get() - # self.tab_info = tf.get_variable( - # "t_tab_info", - # tab_info.shape, - # dtype=tf.float64, - # trainable=False, - # initializer=tf.constant_initializer(tab_info, dtype=tf.float64), - # ) - # self.tab_data = tf.get_variable( - # "t_tab_data", - # tab_data.shape, - # dtype=tf.float64, - # trainable=False, - # initializer=tf.constant_initializer(tab_data, dtype=tf.float64), - # ) coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) atype = paddle.reshape(atype_, [-1, natoms[1]]) - # input_dict["nframes"] = paddle.shape(coord)[0] # 推理模型导出的时候注释掉这里,否则会报错 - - # type embedding if any - # if self.typeebd is not None: - # type_embedding = self.typeebd.build( - # self.ntypes, - # reuse=reuse, - # suffix=suffix, - # ) - # input_dict["type_embedding"] = type_embedding - # spin if any - # if self.spin is not None: - # type_spin = self.spin.build( - # reuse=reuse, - # suffix=suffix, - # ) input_dict["atype"] = atype_ dout = self.descrpt( @@ -234,56 +190,14 @@ 
def forward( box, mesh, input_dict, - # frz_model=frz_model, - # ckpt_meta=ckpt_meta, suffix=suffix, reuse=reuse, - ) # [1, all_atom, M1*M2] - # self.dout = dout - - # if self.srtab is not None: - # nlist, rij, sel_a, sel_r = self.descrpt.get_nlist() - # nnei_a = np.cumsum(sel_a)[-1] - # nnei_r = np.cumsum(sel_r)[-1] + ) atom_ener = self.fitting(dout, natoms, input_dict, reuse=reuse, suffix=suffix) self.atom_ener = atom_ener - # if self.srtab is not None: - # sw_lambda, sw_deriv = op_module.soft_min_switch( - # atype, - # rij, - # nlist, - # natoms, - # sel_a=sel_a, - # sel_r=sel_r, - # alpha=self.smin_alpha, - # rmin=self.sw_rmin, - # rmax=self.sw_rmax, - # ) - # inv_sw_lambda = 1.0 - sw_lambda - # # NOTICE: - # # atom energy is not scaled, - # # force and virial are scaled - # tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab( - # self.tab_info, - # self.tab_data, - # atype, - # rij, - # nlist, - # natoms, - # sw_lambda, - # sel_a=sel_a, - # sel_r=sel_r, - # ) - # energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]]) - # tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape( - # tab_atom_ener, [-1] - # ) - # atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener - # energy_raw = tab_atom_ener + atom_ener - # else: - energy_raw = atom_ener # [1, all_atoms] + energy_raw = atom_ener nloc_atom = ( natoms[0] @@ -298,15 +212,9 @@ def forward( force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms) # force: [1, all_atoms*3] # virial: [1, 9] - # force: [1, all_atoms*9] - - # if self.srtab is not None: - # sw_force = op_module.soft_min_force( - # energy_diff, sw_deriv, nlist, natoms, n_a_sel=nnei_a, n_r_sel=nnei_r - # ) - # force = force + sw_force + tab_force + # atom_virial: [1, all_atoms*9] - force = paddle.reshape(force, [-1, 3 * natoms[1]]) # [1, all_atoms*3] + force = paddle.reshape(force, [-1, 3 * natoms[1]]) if self.spin is not None: # split and concatenate force to compute local atom force and magnetic force judge = paddle.equal(natoms[0], natoms[1]) @@ -318,36 +226,19 @@ def forward( force = paddle.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix) - # if self.srtab is not None: - # sw_virial, sw_atom_virial = op_module.soft_min_virial( - # energy_diff, - # sw_deriv, - # rij, - # nlist, - # natoms, - # n_a_sel=nnei_a, - # n_r_sel=nnei_r, - # ) - # atom_virial = atom_virial + sw_atom_virial + tab_atom_virial - # virial = ( - # virial - # + sw_virial - # + tf.sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1) - # ) - virial = paddle.reshape(virial, [-1, 9], name="o_virial" + suffix) atom_virial = paddle.reshape( atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix ) model_dict = {} - model_dict["energy"] = energy # [batch_size] - model_dict["force"] = force # [batch_size, 576] - model_dict["virial"] = virial # [batch_size, 9] - model_dict["atom_ener"] = energy_raw # [batch_size, 192] - model_dict["atom_virial"] = atom_virial # [batch_size, 1728] - model_dict["coord"] = coord # [batch_size, 576] - model_dict["atype"] = atype # [batch_size, 192] + model_dict["energy"] = energy + model_dict["force"] = force + model_dict["virial"] = virial + model_dict["atom_ener"] = energy_raw + model_dict["atom_virial"] = atom_virial + model_dict["coord"] = coord + model_dict["atype"] = atype return model_dict def init_variables( diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 44ab53fabe..cc46dc9801 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -848,39 +848,6 @@ def train(self, 
train_data=None, valid_data=None, stop_batch: int = 10): ) ) - # prf_options = None - # prf_run_metadata = None - # if self.profiling: - # prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - # prf_run_metadata = tf.RunMetadata() - - # set tensorboard execution environment - # if self.tensorboard: - # summary_merged_op = tf.summary.merge_all() - # # Remove TB old logging directory from previous run - # try: - # shutil.rmtree(self.tensorboard_log_dir) - # except FileNotFoundError: - # pass # directory does not exist, this is OK - # except Exception as e: - # # general error when removing directory, warn user - # log.exception( - # f"Could not remove old tensorboard logging directory: " - # f"{self.tensorboard_log_dir}. Error: {e}" - # ) - # else: - # log.debug("Removing old tensorboard log directory.") - # tb_train_writer = tf.summary.FileWriter( - # self.tensorboard_log_dir + "/train", self.sess.graph - # ) - # tb_valid_writer = tf.summary.FileWriter(self.tensorboard_log_dir + "/test") - # else: - # tb_train_writer = None - # tb_valid_writer = None - # if self.enable_profiler: - # # https://www.tensorflow.org/guide/profiler - # tfv2.profiler.experimental.start(self.tensorboard_log_dir) - train_time = 0 total_train_time = 0.0 wall_time_tic = time.time() @@ -898,35 +865,6 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): while cur_batch < stop_batch: train_batch = datasetloader.get_data_dict() - # first round validation: - # if is_first_step: - # if not self.multi_task_mode: - # train_batch = train_data.get_batch() - # # batch_train_op = self.train_op - # else: - # fitting_idx = dp_random.choice( - # np.arange(self.nfitting), p=np.array(self.fitting_prob) - # ) - # fitting_key = self.fitting_key_list[fitting_idx] - # train_batch = train_data[fitting_key].get_batch() - # # batch_train_op = self.train_op[fitting_key] - # else: - # train_batch = next_datasetloader.get_data_dict(next_train_batch_list) - # # batch_train_op = next_batch_train_op - # fitting_key = next_fitting_key - # for next round - # if not self.multi_task_mode: - # next_datasetloader = datasetloader - # next_batch_train_op = self.train_op - # next_train_batch_op = data_op - # else: - # fitting_idx = dp_random.choice( - # np.arange(self.nfitting), p=np.array(self.fitting_prob) - # ) - # next_fitting_key = self.fitting_key_list[fitting_idx] - # next_datasetloader = datasetloader[next_fitting_key] - # next_batch_train_op = self.train_op[fitting_key] - # next_train_batch_op = data_op[fitting_key] if self.display_in_training and is_first_step: if self.run_opt.is_chief: @@ -972,18 +910,9 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): if self.timing_in_training: tic = time.time() - # train_feed_dict = self.get_feed_dict(train_batch, is_training=True) # use tensorboard to visualize the training of deepmd-kit # it will takes some extra execution time to generate the tensorboard data if self.tensorboard and (cur_batch % self.tensorboard_freq == 0): - # summary, _, next_train_batch_list = run_sess( - # self.sess, - # [summary_merged_op, batch_train_op, next_train_batch_op], - # feed_dict=train_feed_dict, - # options=prf_options, - # run_metadata=prf_run_metadata, - # ) - # tb_train_writer.add_summary(summary, cur_batch) model_pred = self.model( paddle.to_tensor(train_batch["coord"], "float32"), paddle.to_tensor(train_batch["type"], "int32"), @@ -995,30 +924,6 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): reuse=False, ) else: - # for k, v in 
train_feed_dict.items(): - # print(f"{k} {v.shape if hasattr(v, 'shape') else v}") - """ - find_box:0", dtype=float32) () - find_coord:0", dtype=float32) () - find_numb_copy:0", dtype=float32) () - find_energy:0", dtype=float32) () - find_force:0", dtype=float32) () - find_virial:0", dtype=float32) () - find_atom_ener:0", dtype=float32) () - find_atom_pref:0", dtype=float32) () - box:0", shape=(?,), dtype=float64) (9,) - coord:0", shape=(?,), dtype=float64) (576,) - numb_copy:0", shape=(?,), dtype=float64) (1,) - energy:0", shape=(?,), dtype=float64) (1,) - force:0", shape=(?,), dtype=float64) (576,) - virial:0", shape=(?,), dtype=float64) (9,) - atom_ener:0", shape=(?,), dtype=float64) (192,) - atom_pref:0", shape=(?,), dtype=float64) (576,) - natoms:0", shape=(4,), dtype=int32) (4,) - mesh:0", shape=(?,), dtype=int32) (6,) - type:0", shape=(?,), dtype=int32) (192,) - aceholder:0", dtype=bool) True - """ model_inputs = {} for kk in train_batch.keys(): if kk == "find_type" or kk == "type": @@ -1054,12 +959,6 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): reuse=False, ) - # loss = ( - # model_pred["force"].sum() - # + model_pred["virial"].sum() - # + model_pred["energy"].sum() - # + model_pred["atom_ener"].sum() - # ) # print(f"{self.cur_batch} {self.learning_rate.get_lr():.10f}") l2_l, l2_more = self.loss.compute_loss( self.learning_rate.get_lr(), @@ -1074,34 +973,6 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): self.optimizer.step() self.global_step += 1 - # _, next_train_batch_list = run_sess( - # self.sess, - # [batch_train_op, next_train_batch_op], - # feed_dict=train_feed_dict, - # options=prf_options, - # run_metadata=prf_run_metadata, - # ) - """next_train_batch_list - find_box (): none - box (1, 9): (1, 9) - find_coord (): none - coord (1, 576): (1, 576) - find_numb_copy (): none - numb_copy (1, 1): (1, 1) - find_energy (): none - energy (1, 1): (1, 1) - find_force (): none - force (1, 576): (1, 576) - find_virial (): none - virial (1, 9): (1, 9) - find_atom_ener (): none - atom_ener (1, 192): (1, 192) - find_atom_pref (): none - atom_pref (1, 576): (1, 576) - type (1, 192): (1, 192) - natoms_vec (4,): (4,) - default_mesh (6,): (6,) - """ if self.timing_in_training: toc = time.time() if self.timing_in_training: @@ -1183,42 +1054,7 @@ def train(self, train_data=None, valid_data=None, stop_batch: int = 10): total_train_time / (stop_batch // self.disp_freq * self.disp_freq), ) - # if self.profiling and self.run_opt.is_chief: - # fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats) - # chrome_trace = fetched_timeline.generate_chrome_trace_format() - # with open(self.profiling_file, "w") as f: - # f.write(chrome_trace) - # if self.enable_profiler and self.run_opt.is_chief: - # tfv2.profiler.experimental.stop() - def save_checkpoint(self, cur_batch: int): - # try: - # ckpt_prefix = self.saver.save( - # self.sess, - # os.path.join(os.getcwd(), self.save_ckpt), - # global_step=cur_batch, - # ) - # except google.protobuf.message.DecodeError as e: - # raise GraphTooLargeError( - # "The graph size exceeds 2 GB, the hard limitation of protobuf." - # " Then a DecodeError was raised by protobuf. You should " - # "reduce the size of your model." 
- # ) from e - # # make symlinks from prefix with step to that without step to break nothing - # # get all checkpoint files - # original_files = glob.glob(ckpt_prefix + ".*") - # for ori_ff in original_files: - # new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix) :] - # try: - # # remove old one - # os.remove(new_ff) - # except OSError: - # pass - # if platform.system() != "Windows": - # # by default one does not have access to create symlink on Windows - # os.symlink(ori_ff, new_ff) - # else: - # shutil.copyfile(ori_ff, new_ff) paddle.save(self.model.state_dict(), f"Model_{cur_batch}.pdparams") paddle.save(self.optimizer.state_dict(), f"Optimier_{cur_batch}.pdopt") log.info("saved checkpoint %s" % self.save_ckpt) diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index 53ad84e5d8..f393618cb1 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -196,7 +196,6 @@ def execute_with_batch_size( for rr in result: rr.reshape((n_batch, -1)) results.append(result) - # print(__file__, "here") r = tuple([np.concatenate(r, axis=0) for r in zip(*results)]) if len(r) == 1: diff --git a/deepmd/utils/learning_rate.py b/deepmd/utils/learning_rate.py index d17e99dd1b..0f1ccdf5cf 100644 --- a/deepmd/utils/learning_rate.py +++ b/deepmd/utils/learning_rate.py @@ -92,16 +92,10 @@ def build( np.log(self.stop_lr_ / self.start_lr_) / (stop_step / self.decay_steps_) ) - # print("decay_steps_ = ", self.decay_steps_) return lr.ExponentialDecay( self.start_lr_, gamma=self.decay_rate_, ) - # return paddle.optimizer.lr.ExponentialDecay( - # learning_rate=self.start_lr_, - # gamma=self.decay_rate_ ** (1 / self.decay_steps_), - # # verbose=True, - # ) def start_lr(self) -> float: """Get the start lr.""" diff --git a/deepmd/utils/neighbor_stat.py b/deepmd/utils/neighbor_stat.py index fe5fe04bc1..966645996a 100644 --- a/deepmd/utils/neighbor_stat.py +++ b/deepmd/utils/neighbor_stat.py @@ -6,6 +6,7 @@ ) import numpy as np +import paddle from deepmd.env import ( op_module, @@ -42,44 +43,6 @@ def __init__( self.rcut = rcut self.ntypes = ntypes self.one_type = one_type - # sub_graph = tf.Graph() - - # def builder(): - # place_holders = {} - # for ii in ["coord", "box"]: - # place_holders[ii] = tf.placeholder( - # GLOBAL_NP_FLOAT_PRECISION, [None, None], name="t_" + ii - # ) - # place_holders["type"] = tf.placeholder( - # tf.int32, [None, None], name="t_type" - # ) - # place_holders["natoms_vec"] = tf.placeholder( - # tf.int32, [self.ntypes + 2], name="t_natoms" - # ) - # place_holders["default_mesh"] = tf.placeholder( - # tf.int32, [None], name="t_mesh" - # ) - # t_type = place_holders["type"] - # t_natoms = place_holders["natoms_vec"] - # if self.one_type: - # # all types = 0, natoms_vec = [natoms, natoms, natoms] - # t_type = tf.clip_by_value(t_type, -1, 0) - # t_natoms = tf.tile(t_natoms[0:1], [3]) - # _max_nbor_size, _min_nbor_dist = op_module.neighbor_stat( # 这里只计算一次 - # place_holders["coord"], - # t_type, - # t_natoms, - # place_holders["box"], - # place_holders["default_mesh"], - # rcut=self.rcut, - # ) - # place_holders["dir"] = tf.placeholder(tf.string) - # return place_holders, (_max_nbor_size, _min_nbor_dist, place_holders["dir"]) - - # with sub_graph.as_default(): - # self.p = ParallelOp(builder, config=default_tf_session_config) - - # self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]: """Get the data statistics of the training data, including nearest nbor distance between atoms, 
max nbor size of atoms. @@ -101,37 +64,6 @@ def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]: if not self.one_type: self.max_nbor_size *= self.ntypes - # def feed(): - # for ii in range(len(data.system_dirs)): - # for jj in data.data_systems[ii].dirs: - # data_set = data.data_systems[ii]._load_set(jj) - # for kk in range(np.array(data_set["type"]).shape[0]): - # ret = { - # "coord": np.array(data_set["coord"])[kk].reshape( - # [-1, data.natoms[ii] * 3] - # ), # (1, 576) - # "type": np.array(data_set["type"])[kk].reshape( - # [-1, data.natoms[ii]] - # ), # (1, 192) - # "natoms_vec": np.array(data.natoms_vec[ii]), # (4,) - # "box": np.array(data_set["box"])[kk].reshape([-1, 9]), # (1, 9) - # "default_mesh": np.array(data.default_mesh[ii]), # (6,) - # "dir": str(jj), # ../data/data_0/set.xxx - # } - # print(str(jj)) - # print("coord", ret["coord"].shape, ret["coord"].dtype) - # print("type", ret["type"].shape, ret["type"].dtype) - # print("natoms_vec", ret["natoms_vec"].shape, ret["natoms_vec"].dtype) - # print("box", ret["box"].shape, ret["box"].dtype) - # print("default_mesh", ret["default_mesh"].shape, ret["default_mesh"].dtype) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/coord.npy", ret["coord"]) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/type.npy", ret["type"]) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/natoms_vec.npy", ret["natoms_vec"]) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/box.npy", ret["box"]) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/default_mesh.npy", ret["default_mesh"]) - # yield ret - import paddle - for ii in range(len(data.system_dirs)): for jj in data.data_systems[ii].dirs: data_set = data.data_systems[ii]._load_set(jj) @@ -139,29 +71,25 @@ def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]: coord = np.array(data_set["coord"])[kk].reshape( [-1, data.natoms[ii] * 3] ) - coord = paddle.to_tensor( - coord, dtype="float32", place="cpu" - ) # [1, 576] + coord = paddle.to_tensor(coord, dtype="float32", place="cpu") _type = np.array(data_set["type"])[kk].reshape( [-1, data.natoms[ii]] ) - _type = paddle.to_tensor( - _type, dtype="int32", place="cpu" - ) # [1, 192] + _type = paddle.to_tensor(_type, dtype="int32", place="cpu") natoms_vec = np.array(data.natoms_vec[ii]) natoms_vec = paddle.to_tensor( natoms_vec, dtype="int64", place="cpu" - ) # [4] + ) box = np.array(data_set["box"])[kk].reshape([-1, 9]) - box = paddle.to_tensor(box, dtype="float32", place="cpu") # [1, 9] + box = paddle.to_tensor(box, dtype="float32", place="cpu") default_mesh = np.array(data.default_mesh[ii]) default_mesh = paddle.to_tensor( default_mesh, dtype="int32", place="cpu" - ) # [6] + ) rcut = self.rcut mn, dt = op_module.neighbor_stat( @@ -192,30 +120,6 @@ def get_stat(self, data: DeepmdDataSystem) -> Tuple[float, List[int]]: var = paddle.max(mn, axis=0).numpy() self.max_nbor_size = np.maximum(var, self.max_nbor_size) - # for mn, dt, jj in self.p.generate(self.sub_sess, feed()): # _max_nbor_size, _min_nbor_dist, dir - # # print(mn.shape, dt.shape, jj) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/max_nbor_size.npy", mn) - # # np.save("/workspace/hesensen/deepmd-kit/cuda_ext/min_nbor_dist.npy", dt) - # if dt.size != 0: - # dt = np.min(dt) - # else: - # dt = self.rcut - # log.warning( - # "Atoms with no neighbors found in %s. Please make sure it's what you expected." 
- # % jj - # ) - # if dt < self.min_nbor_dist: - # if math.isclose(dt, 0.0, rel_tol=1e-6): - # # it's unexpected that the distance between two atoms is zero - # # zero distance will cause nan (#874) - # raise RuntimeError( - # "Some atoms are overlapping in %s. Please check your" - # " training data to remove duplicated atoms." % jj - # ) - # self.min_nbor_dist = dt - # var = np.max(mn, axis=0) - # self.max_nbor_size = np.maximum(var, self.max_nbor_size) - log.info("training data with min nbor dist: " + str(self.min_nbor_dist)) log.info("training data with max nbor size: " + str(self.max_nbor_size)) return self.min_nbor_dist, self.max_nbor_size diff --git a/deepmd/utils/network.py b/deepmd/utils/network.py index c25c5f0589..58e6378215 100644 --- a/deepmd/utils/network.py +++ b/deepmd/utils/network.py @@ -489,27 +489,3 @@ def forward(self, xx): xx = hidden return xx - - # == debug code below ==# - # hidden = nn.functional.tanh( - # nn.functional.linear(xx, self.weight[0], self.bias[0]) - # ).reshape( - # [-1, 25] - # ) # 1 - # xx = hidden # 7 - - # hidden = nn.functional.tanh( - # nn.functional.linear(xx, self.weight[1], self.bias[1]) - # ).reshape( - # [-1, 50] - # ) # 1 - # xx = paddle.concat([xx, xx], axis=1) + hidden # 6 - - # hidden = nn.functional.tanh( - # nn.functional.linear(xx, self.weight[2], self.bias[2]) - # ).reshape( - # [-1, 100] - # ) # 1 - # xx = paddle.concat([xx, xx], axis=1) + hidden # 6 - - # return xx diff --git a/examples/water/lmp/Model_1000000_with_buffer.pdiparams b/examples/water/lmp/Model_1000000_with_buffer.pdiparams deleted file mode 120000 index 569305c94f..0000000000 --- a/examples/water/lmp/Model_1000000_with_buffer.pdiparams +++ /dev/null @@ -1 +0,0 @@ -/workspace/hesensen/deepmd_backend/deepmd-kit/examples/water/se_e2_a/Model_1000000_with_buffer.pdiparams \ No newline at end of file diff --git a/examples/water/lmp/Model_1000000_with_buffer.pdmodel b/examples/water/lmp/Model_1000000_with_buffer.pdmodel deleted file mode 120000 index 175c7e7080..0000000000 --- a/examples/water/lmp/Model_1000000_with_buffer.pdmodel +++ /dev/null @@ -1 +0,0 @@ -/workspace/hesensen/deepmd_backend/deepmd-kit/examples/water/se_e2_a/Model_1000000_with_buffer.pdmodel \ No newline at end of file
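Note on the learning-rate schedule kept in deepmd/utils/learning_rate.py above: LearningRateExp.build() picks the decay rate as exp(log(stop_lr / start_lr) / (stop_step / decay_steps)), so that after stop_step / decay_steps decay events the learning rate lands exactly on stop_lr. Below is a minimal standalone sketch of that arithmetic only; the concrete start_lr, stop_lr, stop_step and decay_steps values are illustrative assumptions, not values taken from this patch, and how often Paddle's ExponentialDecay scheduler is stepped is not covered here.

import numpy as np

# Assumed example settings (not from this patch).
start_lr, stop_lr = 1.0e-3, 3.51e-8
stop_step, decay_steps = 1_000_000, 5_000

# Same decay-rate formula as in LearningRateExp.build():
# one decay event every `decay_steps` training steps.
gamma = np.exp(np.log(stop_lr / start_lr) / (stop_step / decay_steps))

# After stop_step / decay_steps decays the schedule reaches stop_lr.
n_decays = stop_step // decay_steps
final_lr = start_lr * gamma**n_decays
assert np.isclose(final_lr, stop_lr)
print(f"gamma = {gamma:.6f}, final lr = {final_lr:.3e}")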