From 4c618d7b01fce48ee0b24c7348fc3bddeccb2fad Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sat, 13 Jan 2024 04:22:41 +0800 Subject: [PATCH] [Paddle Backend] Add zinc protein (#3135) Add zinc protein example and fix typos in zinc_protein config. --------- Signed-off-by: HydrogenSulfate <490868991@qq.com> Co-authored-by: xusuyong <2209245477@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/descriptor/se_a_mask.py | 573 ++++++++++++++---- deepmd/fit/ener.py | 11 +- deepmd/train/trainer.py | 11 +- examples/zinc_protein/zinc_se_a_mask.json | 6 +- source/lib/paddle_src/custom_op_install.py | 3 + source/lib/paddle_src/custom_op_test.py | 103 ++++ .../paddle_src/paddle_descrpt_se_a_mask.cc | 342 +++++++++++ .../paddle_src/paddle_prod_force_se_a_mask.cc | 157 +++++ .../paddle_prod_force_se_a_mask_grad.cc | 149 +++++ 9 files changed, 1241 insertions(+), 114 deletions(-) create mode 100644 source/lib/paddle_src/paddle_descrpt_se_a_mask.cc create mode 100644 source/lib/paddle_src/paddle_prod_force_se_a_mask.cc create mode 100644 source/lib/paddle_src/paddle_prod_force_se_a_mask_grad.cc diff --git a/deepmd/descriptor/se_a_mask.py b/deepmd/descriptor/se_a_mask.py index b47389d7e7..b2c1a91d0e 100644 --- a/deepmd/descriptor/se_a_mask.py +++ b/deepmd/descriptor/se_a_mask.py @@ -14,26 +14,17 @@ get_precision, ) from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, - GLOBAL_TF_FLOAT_PRECISION, - default_tf_session_config, + GLOBAL_PD_FLOAT_PRECISION, op_module, - tf, + paddle, ) from deepmd.utils.network import ( + EmbeddingNet, embedding_net_rand_seed_shift, ) -from .descriptor import ( - Descriptor, -) -from .se_a import ( - DescrptSeA, -) - -@Descriptor.register("se_a_mask") -class DescrptSeAMask(DescrptSeA): +class DescrptSeAMask(paddle.nn.Layer): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. 
@@ -127,6 +118,7 @@ def __init__( precision: str = "default", uniform_seed: bool = False, ) -> None: + super().__init__() """Constructor.""" self.sel_a = sel self.total_atom_num = np.cumsum(self.sel_a)[-1] @@ -166,48 +158,34 @@ def __init__( self.compress = False self.embedding_net_variables = None self.mixed_prec = None - self.place_holders = {} + # self.place_holders = {} nei_type = np.array([]) for ii in range(self.ntypes): nei_type = np.append(nei_type, ii * np.ones(self.sel_a[ii])) # like a mask - self.nei_type = tf.constant(nei_type, dtype=tf.int32) - - avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype( - GLOBAL_NP_FLOAT_PRECISION + # self.nei_type = tf.constant(nei_type, dtype=tf.int32) + # self.nei_type = paddle.to_tensor(nei_type, dtype="int32") + self.register_buffer( + "buffer_ntypes_spin", paddle.to_tensor(nei_type, dtype="int32") ) - std_ones = np.ones([self.ntypes, self.ndescrpt]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) - sub_graph = tf.Graph() - with sub_graph.as_default(): - name_pfx = "d_sea_mask_" - for ii in ["coord", "box"]: - self.place_holders[ii] = tf.placeholder( - GLOBAL_NP_FLOAT_PRECISION, [None, None], name=name_pfx + "t_" + ii + + nets = [] + for type_input in range(self.ntypes): + layer = [] + for type_i in range(self.ntypes): + layer.append( + EmbeddingNet( + self.filter_neuron, + self.filter_precision, + self.filter_activation_fn, + self.filter_resnet_dt, + self.seed, + self.trainable, + name="filter_type_" + str(type_input) + str(type_i), + ) ) - self.place_holders["type"] = tf.placeholder( - tf.int32, [None, None], name=name_pfx + "t_type" - ) - self.place_holders["mask"] = tf.placeholder( - tf.int32, [None, None], name=name_pfx + "t_aparam" - ) # named aparam for inference compatibility in c++ interface. - - self.place_holders["natoms_vec"] = tf.placeholder( - tf.int32, [self.ntypes + 2], name=name_pfx + "t_natoms" - ) # Not used in se_a_mask. For compatibility with c++ interface. - self.place_holders["default_mesh"] = tf.placeholder( - tf.int32, [None], name=name_pfx + "t_mesh" - ) # Not used in se_a_mask. For compatibility with c++ interface. - - self.stat_descrpt, descrpt_deriv, rij, nlist = op_module.descrpt_se_a_mask( - self.place_holders["coord"], - self.place_holders["type"], - self.place_holders["mask"], - self.place_holders["box"], - self.place_holders["natoms_vec"], - self.place_holders["default_mesh"], - ) - self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) + nets.append(paddle.nn.LayerList(layer)) + + self.embedding_nets = paddle.nn.LayerList(nets) self.original_sel = None def get_rcut(self) -> float: @@ -215,6 +193,18 @@ def get_rcut(self) -> float: warnings.warn("The cutoff radius is not used for this descriptor") return -1.0 + def get_ntypes(self) -> int: + """Returns the number of atom types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.n_axis_neuron + + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. 
The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + def compute_input_stats( self, data_coord: list, @@ -249,17 +239,17 @@ def compute_input_stats( self.davg = None self.dstd = None - def build( + def forward( self, - coord_: tf.Tensor, - atype_: tf.Tensor, - natoms: tf.Tensor, - box_: tf.Tensor, - mesh: tf.Tensor, + coord_: paddle.Tensor, + atype_: paddle.Tensor, + natoms: paddle.Tensor, + box_: paddle.Tensor, + mesh: paddle.Tensor, input_dict: Dict[str, Any], reuse: Optional[bool] = None, suffix: str = "", - ) -> tf.Tensor: + ) -> paddle.Tensor: """Build the computational graph for the descriptor. Parameters @@ -299,40 +289,49 @@ def build( aparam[:, :] is the real/virtual sign for each atom. """ aparam = input_dict["aparam"] - self.mask = tf.cast(aparam, tf.int32) - self.mask = tf.reshape(self.mask, [-1, natoms[1]]) - - with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): - if davg is None: - davg = np.zeros([self.ntypes, self.ndescrpt]) - if dstd is None: - dstd = np.ones([self.ntypes, self.ndescrpt]) - t_rcut = tf.constant( - self.rcut, - name="rcut", - dtype=GLOBAL_TF_FLOAT_PRECISION, - ) - t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32) - t_ndescrpt = tf.constant(self.ndescrpt, name="ndescrpt", dtype=tf.int32) - t_sel = tf.constant(self.sel_a, name="sel", dtype=tf.int32) - """ - self.t_avg = tf.get_variable('t_avg', - davg.shape, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(davg)) - self.t_std = tf.get_variable('t_std', - dstd.shape, - dtype = GLOBAL_TF_FLOAT_PRECISION, - trainable = False, - initializer = tf.constant_initializer(dstd)) - """ - - coord = tf.reshape(coord_, [-1, natoms[1] * 3]) - box_ = tf.reshape( + + self.mask = paddle.cast(aparam, paddle.int32) + self.mask = paddle.reshape(self.mask, [-1, natoms[1]]) + # with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): + if davg is None: + davg = np.zeros([self.ntypes, self.ndescrpt]) + if dstd is None: + dstd = np.ones([self.ntypes, self.ndescrpt]) + # t_rcut = tf.constant( + # self.rcut, + # name="rcut", + # dtype=GLOBAL_TF_FLOAT_PRECISION, + # ) + # t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32) + # t_ndescrpt = tf.constant(self.ndescrpt, name="ndescrpt", dtype=tf.int32) + # t_sel = tf.constant(self.sel_a, name="sel", dtype=tf.int32) + # """ + # self.t_avg = tf.get_variable('t_avg', + # davg.shape, + # dtype = GLOBAL_TF_FLOAT_PRECISION, + # trainable = False, + # initializer = tf.constant_initializer(davg)) + # self.t_std = tf.get_variable('t_std', + # dstd.shape, + # dtype = GLOBAL_TF_FLOAT_PRECISION, + # trainable = False, + # initializer = tf.constant_initializer(dstd)) + # """ + + coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) + + box_ = paddle.reshape( box_, [-1, 9] ) # Not used in se_a_mask descriptor. For compatibility in c++ inference. 
-        atype = tf.reshape(atype_, [-1, natoms[1]])
+
+        atype = paddle.reshape(atype_, [-1, natoms[1]])
+
+        coord = paddle.to_tensor(coord, place="cpu")
+        atype = paddle.to_tensor(atype, place="cpu")
+        self.mask = paddle.to_tensor(self.mask, place="cpu")
+        box_ = paddle.to_tensor(box_, place="cpu")
+        natoms = paddle.to_tensor(natoms, place="cpu")
+        mesh = paddle.to_tensor(mesh, place="cpu")

         (
             self.descrpt,
@@ -341,12 +340,13 @@ def build(
             self.nlist,
         ) = op_module.descrpt_se_a_mask(coord, atype, self.mask, box_, natoms, mesh)
         # only used when tensorboard was set as true
-        tf.summary.histogram("descrpt", self.descrpt)
-        tf.summary.histogram("rij", self.rij)
-        tf.summary.histogram("nlist", self.nlist)
+        # tf.summary.histogram("descrpt", self.descrpt)
+        # tf.summary.histogram("rij", self.rij)
+        # tf.summary.histogram("nlist", self.nlist)

-        self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
-        self._identity_tensors(suffix=suffix)
+        self.descrpt_reshape = paddle.reshape(self.descrpt, [-1, self.ndescrpt])
+        # self._identity_tensors(suffix=suffix)
+        self.descrpt_reshape.stop_gradient = False

         self.dout, self.qmat = self._pass_filter(
             self.descrpt_reshape,
@@ -359,14 +359,14 @@ def build(
         )

         # only used when tensorboard was set as true
-        tf.summary.histogram("embedding_net_output", self.dout)
+        # tf.summary.histogram("embedding_net_output", self.dout)
         return self.dout

     def prod_force_virial(
         self,
-        atom_ener: tf.Tensor,
-        natoms: tf.Tensor,
-    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        atom_ener: paddle.Tensor,
+        natoms: paddle.Tensor,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute force and virial.

         Parameters
         ----------
@@ -388,9 +388,10 @@ def build(
         atom_virial
             None for se_a_mask op
         """
-        [net_deriv] = tf.gradients(atom_ener, self.descrpt_reshape)
-        tf.summary.histogram("net_derivative", net_deriv)
-        net_deriv_reshape = tf.reshape(net_deriv, [-1, natoms[0] * self.ndescrpt])
+        net_deriv = paddle.grad(atom_ener, self.descrpt_reshape, create_graph=True)[0]
+        # tf.summary.histogram("net_derivative", net_deriv)
+        net_deriv_reshape = paddle.reshape(net_deriv, [-1, natoms[0] * self.ndescrpt])
+        net_deriv_reshape = paddle.to_tensor(net_deriv_reshape, place="cpu")
         force = op_module.prod_force_se_a_mask(
             net_deriv_reshape,
             self.descrpt_deriv,
             self.mask,
             self.nlist,
             total_atom_num=self.total_atom_num,
         )
@@ -399,11 +400,375 @@
-        tf.summary.histogram("force", force)
+        # tf.summary.histogram("force", force)

         # Construct virial and atom virial tensors to avoid reshape errors in model/ener.py
         # They are not used in se_a_mask op
-        virial = tf.zeros([1, 9], dtype=force.dtype)
-        atom_virial = tf.zeros([1, natoms[1], 9], dtype=force.dtype)
+        virial = paddle.zeros([1, 9], dtype=force.dtype)
+        atom_virial = paddle.zeros([1, natoms[1], 9], dtype=force.dtype)

         return force, virial, atom_virial
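One porting note worth spelling out: the graph-mode `tf.gradients` call becomes an eager `paddle.grad(..., create_graph=True)`, which keeps the derivative itself differentiable so a force-matching loss can still backpropagate through it. A minimal standalone sketch of that pattern (toy shapes, not the repo's API):

    import paddle

    # Toy stand-in for the descriptor tensor; stop_gradient=False mirrors
    # `self.descrpt_reshape.stop_gradient = False` above.
    descrpt = paddle.randn([4, 16])
    descrpt.stop_gradient = False

    # Toy stand-in for the fitting net producing the energy.
    atom_ener = (descrpt * descrpt).sum()

    # First derivative, kept differentiable so a force loss can backprop.
    net_deriv = paddle.grad(atom_ener, descrpt, create_graph=True)[0]

    loss = (net_deriv**2).mean()  # e.g. a force-matching term
    loss.backward()  # works only because create_graph=True above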
+
+    def _pass_filter(
+        self, inputs, atype, natoms, input_dict, reuse=None, suffix="", trainable=True
+    ):
+        """Pass the descriptor through the per-type embedding filters.
+
+        Parameters
+        ----------
+        inputs : paddle.Tensor
+            Inputs tensor.
+        atype : paddle.Tensor
+            Atom type Tensor.
+        natoms : paddle.Tensor
+            Number of atoms vector.
+        input_dict : Dict[str, paddle.Tensor]
+            Input data dict.
+        reuse : bool, optional
+            Whether to reuse variables. Defaults to None.
+        suffix : str, optional
+            Variable suffix. Defaults to "".
+        trainable : bool, optional
+            Whether to make the subnetwork trainable. Defaults to True.
+
+        Returns
+        -------
+        Tuple[Tensor, Tensor]: output: [1, all_atom, M1*M2], output_qmat: [1, all_atom, M1*3]
+        """
+        if input_dict is not None:
+            type_embedding = input_dict.get("type_embedding", None)
+        else:
+            type_embedding = None
+        start_index = 0
+        inputs = paddle.reshape(inputs, [-1, int(natoms[0].item()), int(self.ndescrpt)])
+        output = []
+        output_qmat = []
+        if not self.type_one_side and type_embedding is None:
+            for type_i in range(self.ntypes):
+                inputs_i = paddle.slice(
+                    inputs,
+                    [0, 1, 2],
+                    [0, start_index, 0],
+                    [
+                        inputs.shape[0],
+                        start_index + natoms[2 + type_i].item(),
+                        inputs.shape[2],
+                    ],
+                )
+                inputs_i = paddle.reshape(inputs_i, [-1, self.ndescrpt])
+                filter_name = "filter_type_" + str(type_i) + suffix
+                layer, qmat = self._filter(
+                    inputs_i,
+                    type_i,
+                    name=filter_name,
+                    natoms=natoms,
+                    reuse=reuse,
+                    trainable=trainable,
+                    activation_fn=self.filter_activation_fn,
+                )
+                layer = paddle.reshape(
+                    layer, [inputs.shape[0], natoms[2 + type_i], self.get_dim_out()]
+                )
+                qmat = paddle.reshape(
+                    qmat,
+                    [
+                        inputs.shape[0],
+                        natoms[2 + type_i],
+                        self.get_dim_rot_mat_1() * 3,
+                    ],
+                )
+                output.append(layer)
+                output_qmat.append(qmat)
+                start_index += natoms[2 + type_i].item()
+        else:
+            raise NotImplementedError()
+            # This branch is not executed at present
+            # inputs_i = inputs
+            # inputs_i = paddle.reshape(inputs_i, [-1, self.ndescrpt])
+            # type_i = -1
+            # # if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
+            # #     inputs_i = descrpt2r4(inputs_i, natoms)
+            # if len(self.exclude_types):
+            #     atype_nloc = paddle.reshape(
+            #         paddle.slice(atype, [0, 1], [0, 0], [atype.shape[0], natoms[0]]),
+            #         [-1],
+            #     )  # when nloc != nall, pass nloc to mask
+            #     mask = self.build_type_exclude_mask(
+            #         self.exclude_types,
+            #         self.ntypes,
+            #         self.sel_a,
+            #         self.ndescrpt,
+            #         atype_nloc,
+            #         paddle.shape(inputs_i)[0],
+            #     )
+            #     inputs_i *= mask
+
+            # layer, qmat = self._filter(
+            #     inputs_i,
+            #     type_i,
+            #     name="filter_type_all" + suffix,
+            #     natoms=natoms,
+            #     reuse=reuse,
+            #     trainable=trainable,
+            #     activation_fn=self.filter_activation_fn,
+            #     type_embedding=type_embedding,
+            # )
+            # layer = paddle.reshape(
+            #     layer, [inputs.shape[0], natoms[0], self.get_dim_out()]
+            # )
+            # qmat = paddle.reshape(
+            #     qmat, [inputs.shape[0], natoms[0], self.get_dim_rot_mat_1() * 3]
+            # )
+            # output.append(layer)
+            # output_qmat.append(qmat)
+        output = paddle.concat(output, axis=1)
+        output_qmat = paddle.concat(output_qmat, axis=1)
+        return output, output_qmat
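The slicing bookkeeping in `_pass_filter` relies on the `natoms` layout (`natoms[0]` = nloc, `natoms[1]` = nall, `natoms[2 + i]` = count of atoms of type `i`). A small standalone sketch of the same walk, with illustrative sizes:

    import paddle

    ndescrpt = 8
    natoms = [5, 5, 2, 3]  # [nloc, nall, n_type0, n_type1]; illustrative values
    inputs = paddle.randn([1, natoms[0], ndescrpt])

    start_index = 0
    per_type_rows = []
    for type_i, n_i in enumerate(natoms[2:]):
        # Same role as the paddle.slice call above: rows of atoms of one type.
        inputs_i = inputs[:, start_index : start_index + n_i, :]
        per_type_rows.append(inputs_i.reshape([-1, ndescrpt]))
        start_index += n_i

    assert sum(r.shape[0] for r in per_type_rows) == natoms[0]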
+                    "compression of type embedded descriptor is not supported "
+                    "at the moment"
+                )
+        # natom x 4 x outputs_size
+        if self.compress and (not is_exclude):
+            if self.type_one_side:
+                net = "filter_-1_net_" + str(type_i)
+            else:
+                net = "filter_" + str(type_input) + "_net_" + str(type_i)
+            info = [
+                self.lower[net],
+                self.upper[net],
+                self.upper[net] * self.table_config[0],
+                self.table_config[1],
+                self.table_config[2],
+                self.table_config[3],
+            ]
+            return op_module.tabulate_fusion_se_a(
+                paddle.cast(self.table.data[net], self.filter_precision),
+                info,
+                xyz_scatter,
+                paddle.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
+                last_layer_size=outputs_size[-1],
+            )
+        else:
+            if not is_exclude:
+                # this branch is executed
+                xyz_scatter_out = self.embedding_nets[type_input][type_i](xyz_scatter)
+                if (not self.uniform_seed) and (self.seed is not None):
+                    self.seed += self.seed_shift
+            else:
+                # we can safely return the final xyz_scatter filled with zero directly
+                return paddle.cast(
+                    paddle.full((natom, 4, outputs_size[-1]), 0.0),
+                    self.filter_precision,
+                )
+        # natom x nei_type_i x out_size
+        xyz_scatter_out = paddle.reshape(
+            xyz_scatter_out, (-1, shape_i[1] // 4, outputs_size[-1])
+        )  # (natom x nei_type_i) x 100 ==> natom x nei_type_i x 100
+        # When using paddle.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below
+        # [588 24] -> [588 6 4] correct
+        # but if sel is zero
+        # [588 0] -> [147 0 4] incorrect; the correct one is [588 0 4]
+        # So we need to explicitly assign the shape to paddle.shape(inputs_i)[0] instead of -1
+        # natom x 4 x outputs_size

+        return paddle.matmul(
+            paddle.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
+            xyz_scatter_out,
+            transpose_x=True,
+        )
+
+    def _filter(
+        self,
+        inputs: paddle.Tensor,
+        type_input: int,
+        natoms,
+        type_embedding=None,
+        activation_fn=paddle.nn.functional.tanh,
+        stddev=1.0,
+        bavg=0.0,
+        name="linear",
+        reuse=None,
+        trainable=True,
+    ):
+        """Apply the embedding filter to the environment matrix of one type.
+
+        Parameters
+        ----------
+        inputs : paddle.Tensor
+            Inputs tensor.
+        type_input : int
+            Type of input.
+        natoms : paddle.Tensor
+            Number of atoms, a vector.
+        type_embedding : paddle.Tensor
+            Type embedding. Defaults to None.
+        activation_fn : Callable
+            Activation function. Defaults to paddle.nn.functional.tanh.
+        stddev : float, optional
+            Stddev for parameter initialization. Defaults to 1.0.
+        bavg : float, optional
+            Bavg for parameter initialization. Defaults to 0.0.
+        name : str, optional
+            Name for subnetwork. Defaults to "linear".
+        reuse : bool, optional
+            Whether to reuse variables. Defaults to None.
+        trainable : bool, optional
+            Whether to make the subnetwork trainable. Defaults to True.
+
+        Returns
+        -------
+        Tuple[Tensor, Tensor]: result: [64/128, M1*M2], qmat: [64/128, M1, 3]
+        """
+        # NOTE: the line below is commented out because its nframes computation is wrong
+        # nframes = paddle.shape(paddle.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
+
+        nframes = 1
+        # natom x (nei x 4)
+        shape = inputs.shape
+        outputs_size = [1, *self.filter_neuron]
+        outputs_size_2 = self.n_axis_neuron  # 16
+        all_excluded = all(
+            # FIXME: the brackets '[]' are needed when converting to a static model;
+            # they will be removed once that is fixed.
+            [  # noqa
+                (type_input, type_i) in self.exclude_types  # set()
+                for type_i in range(self.ntypes)
+            ]
+        )
+        if all_excluded:
+            # all types are excluded so result and qmat should be zeros
+            # we can safely return a zero matrix...
+ # See also https://stackoverflow.com/a/34725458/9567349 + # result: natom x outputs_size x outputs_size_2 + # qmat: natom x outputs_size x 3 + natom = paddle.shape(inputs)[0] + result = paddle.cast( + paddle.full((natom, outputs_size_2, outputs_size[-1]), 0.0), + GLOBAL_PD_FLOAT_PRECISION, + ) + qmat = paddle.cast( + paddle.full((natom, outputs_size[-1], 3), 0.0), + GLOBAL_PD_FLOAT_PRECISION, + ) + return result, qmat + + # with tf.variable_scope(name, reuse=reuse): + start_index = 0 + type_i = 0 + # natom x 4 x outputs_size + if type_embedding is None: + rets = [] + # execute this branch + for type_i in range(self.ntypes): + ret = self._filter_lower( + type_i, + type_input, + start_index, + self.sel_a[type_i], # 46(O)/92(H) + inputs, + nframes, + natoms, + type_embedding=type_embedding, + is_exclude=(type_input, type_i) in self.exclude_types, + ) + if (type_input, type_i) not in self.exclude_types: + # add zero is meaningless; skip + rets.append(ret) + start_index += self.sel_a[type_i] + # faster to use accumulate_n than multiple add + xyz_scatter_1 = paddle.add_n(rets) + else: + xyz_scatter_1 = self._filter_lower( + type_i, + type_input, + start_index, + np.cumsum(self.sel_a)[-1], + inputs, + nframes, + natoms, + type_embedding=type_embedding, + is_exclude=False, + ) + # natom x nei x outputs_size + # xyz_scatter = tf.concat(xyz_scatter_total, axis=1) + # natom x nei x 4 + # inputs_reshape = tf.reshape(inputs, [-1, shape[1]//4, 4]) + # natom x 4 x outputs_size + # xyz_scatter_1 = tf.matmul(inputs_reshape, xyz_scatter, transpose_a = True) + if self.original_sel is None: + # shape[1] = nnei * 4 + nnei = shape[1] / 4 + else: + nnei = paddle.cast( + paddle.to_tensor( + np.sum(self.original_sel), + dtype=paddle.int32, + stop_gradient=True, + ), + self.filter_precision, + ) + xyz_scatter_1 = xyz_scatter_1 / nnei + # natom x 4 x outputs_size_2 + xyz_scatter_2 = paddle.slice( + xyz_scatter_1, + [0, 1, 2], + [0, 0, 0], + [xyz_scatter_1.shape[0], xyz_scatter_1.shape[1], outputs_size_2], + ) + # natom x 3 x outputs_size_2 + # qmat = tf.slice(xyz_scatter_2, [0,1,0], [-1, 3, -1]) + # natom x 3 x outputs_size_1 + qmat = paddle.slice( + xyz_scatter_1, + [0, 1, 2], + [0, 1, 0], + [xyz_scatter_1.shape[0], 1 + 3, xyz_scatter_1.shape[2]], + ) + # natom x outputs_size_1 x 3 + qmat = paddle.transpose(qmat, perm=[0, 2, 1]) # [64/128, M1, 3] + # natom x outputs_size x outputs_size_2 + result = paddle.matmul(xyz_scatter_1, xyz_scatter_2, transpose_x=True) + # natom x (outputs_size x outputs_size_2) + result = paddle.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) + + return result, qmat diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 4a218fdc1b..d0ca9fda0c 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -247,7 +247,13 @@ def __init__( else: type_i_layers.append( OneLayer_deepmd( - self.dim_descrpt + self.numb_fparam + self.numb_aparam, + self.dim_descrpt + + self.numb_fparam + + ( + self.numb_aparam + if self.use_aparam_as_mask is False + else 0 + ), self.n_neuron[ii], activation_fn=self.fitting_activation_fn, precision=self.fitting_precision, @@ -691,7 +697,8 @@ def forward( self.atom_ener_before = outs * atype_filter self.add_type = paddle.reshape( paddle.nn.functional.embedding( - self.atype_nloc, self.t_bias_atom_e.reshape([2, -1]) + self.atype_nloc, + self.t_bias_atom_e.reshape([self.t_bias_atom_e.shape[0], -1]), ), [paddle.shape(inputs)[0], paddle.sum(natoms[2 : 2 + ntypes_atom]).item()], ) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 
2ad6a385e9..8f5051570b 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -151,11 +151,12 @@ def _init_param(self, jdata): descrpt_param["multi_task"] = True if descrpt_param["type"] in ["se_e2_a", "se_a", "se_e2_r", "se_r", "hybrid"]: descrpt_param["spin"] = self.spin - descrpt_param.pop("type") - descrpt_param["mixed_prec"] = self.mixed_prec - if descrpt_param["mixed_prec"] is not None: - descrpt_param["precision"]: str = self.mixed_prec["output_prec"] - self.descrpt = deepmd.descriptor.se_a.DescrptSeA(**descrpt_param) + elif descrpt_param["type"] == "se_a_mask": + descrpt_param.pop("type") + self.descrpt = deepmd.descriptor.se_a_mask.DescrptSeAMask(**descrpt_param) + else: + descrpt_param.pop("type") + self.descrpt = deepmd.descriptor.se_a.DescrptSeA(**descrpt_param) # fitting net if not self.multi_task_mode: diff --git a/examples/zinc_protein/zinc_se_a_mask.json b/examples/zinc_protein/zinc_se_a_mask.json index 80a070d726..d6376c346a 100644 --- a/examples/zinc_protein/zinc_se_a_mask.json +++ b/examples/zinc_protein/zinc_se_a_mask.json @@ -66,14 +66,14 @@ "training": { "training_data": { "systems": [ - "example/zinc_protein/train_data_dp_mask/" + "examples/zinc_protein/train_data_dp_mask/" ], "batch_size": 2, "_comment": "that's all" }, "validation_data": { "systems": [ - "example/zinc_protein/val_data_dp_mask/" + "examples/zinc_protein/val_data_dp_mask/" ], "batch_size": 2, "_comment": "that's all" @@ -82,7 +82,7 @@ "seed": 10, "disp_freq": 100, "save_freq": 1000, - "tensorboard": true, + "tensorboard": false, "tensorboard_log_dir": "log4tensorboard", "tensorboard_freq": 100, "_comment": "that's all" diff --git a/source/lib/paddle_src/custom_op_install.py b/source/lib/paddle_src/custom_op_install.py index 6aa4a490da..5d42b71385 100644 --- a/source/lib/paddle_src/custom_op_install.py +++ b/source/lib/paddle_src/custom_op_install.py @@ -52,6 +52,9 @@ "./paddle_prod_force_grad.cu", "./paddle_prod_force_grad.cc", "./paddle_neighbor_stat.cc", + "./paddle_descrpt_se_a_mask.cc", + "./paddle_prod_force_se_a_mask.cc", + "./paddle_prod_force_se_a_mask_grad.cc", ], include_dirs=[ "../../lib/include/", diff --git a/source/lib/paddle_src/custom_op_test.py b/source/lib/paddle_src/custom_op_test.py index 29ac501935..106e398b7b 100644 --- a/source/lib/paddle_src/custom_op_test.py +++ b/source/lib/paddle_src/custom_op_test.py @@ -256,6 +256,107 @@ def test_prod_virial_se_a(place="cpu"): print(np.allclose(atom_virial.numpy(), atom_virial_load)) +def test_prod_force_se_a_mask(place="cpu"): + print("=" * 10, f"test_prod_force_se_a_mask [place={place}]", "=" * 10) + import numpy as np + + net_deriv_reshape = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "prod_force_se_a_mask/net_deriv_reshape.npy")) + ) + descrpt_deriv = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "prod_force_se_a_mask/descrpt_deriv.npy")) + ) + mask = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "prod_force_se_a_mask/mask.npy")) + ) + nlist = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "prod_force_se_a_mask/nlist.npy")) + ) + total_atom_num = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "prod_force_se_a_mask/total_atom_num.npy")) + ) + + net_deriv_reshape = paddle.to_tensor( + net_deriv_reshape, stop_gradient=False, place=place + ) + descrpt_deriv = paddle.to_tensor(descrpt_deriv, "float64", place=place) + mask = paddle.to_tensor(mask, "int32", place=place) + nlist = paddle.to_tensor(nlist, "int32", place=place) + + force = paddle_deepmd_lib.prod_force_se_a_mask( + 
net_deriv_reshape, + descrpt_deriv, + mask, + nlist, + total_atom_num, + ) + force.sum().backward() + # print(f"net_deriv_reshape.grad.shape = {net_deriv_reshape.grad.shape}") + + # print(mn.shape, mn.min().item()); print(mn.max().item()); print(mn.mean().item()); print(mn.var().item()) + # print(mn_load.shape); print(mn_load.min().item()); print(mn_load.max().item()); print(mn_load.mean().item()); print(mn_load.var().item()) + # print(dt.shape, dt.min().item(), dt.max().item(), dt.mean().item(), dt.var().item()) + # print(dt_load.shape, dt_load.min().item(), dt_load.max().item(), dt_load.mean().item(), dt_load.var().item()) + force_load = np.load(osp.join(unitest_dir, "prod_force_se_a_mask/force.npy")) + grad_load = np.load( + osp.join(unitest_dir, "prod_force_se_a_mask/net_deriv_reshape_grad.npy") + ) + + print(np.allclose(force.numpy(), force_load)) + print(np.allclose(net_deriv_reshape.grad.numpy(), grad_load)) + + +def test_descrpt_se_a_mask(place="cpu"): + print("=" * 10, f"test_descrpt_se_a_mask [place={place}]", "=" * 10) + import numpy as np + + coord = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/coord.npy")) + ) # float64 (2,441) + atype = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/atype.npy")) + ) # int32 (2,147) + mask = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/mask.npy")) + ) # int32 (2, 147) + box = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/box.npy")) + ) # float64 (2, 9) + natoms = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/natoms.npy")) + ) # int32 (8,) + mesh = np.ascontiguousarray( + np.load(osp.join(unitest_dir, "descrpt_se_a_mask/mesh.npy")) + ) # int32 (0,) + + coord = paddle.to_tensor(coord, "float64", place=place) + atype = paddle.to_tensor(atype, "int32", place=place) + mask = paddle.to_tensor(mask, "int32", place=place) + box = paddle.to_tensor(box, "float64", place=place) + natoms = paddle.to_tensor(natoms, "int32", place=place) + mesh = paddle.to_tensor(mesh, "int32", place=place) + + descrpt, descrpt_deriv, rij, nlist = paddle_deepmd_lib.descrpt_se_a_mask( + coord, atype, mask, box, natoms, mesh + ) + + # print(mn.shape, mn.min().item()); print(mn.max().item()); print(mn.mean().item()); print(mn.var().item()) + # print(mn_load.shape); print(mn_load.min().item()); print(mn_load.max().item()); print(mn_load.mean().item()); print(mn_load.var().item()) + # print(dt.shape, dt.min().item(), dt.max().item(), dt.mean().item(), dt.var().item()) + # print(dt_load.shape, dt_load.min().item(), dt_load.max().item(), dt_load.mean().item(), dt_load.var().item()) + descrpt_load = np.load(osp.join(unitest_dir, "descrpt_se_a_mask/descrpt.npy")) + descrpt_deriv_load = np.load( + osp.join(unitest_dir, "descrpt_se_a_mask/descrpt_deriv.npy") + ) + rij_load = np.load(osp.join(unitest_dir, "descrpt_se_a_mask/rij.npy")) + nlist_load = np.load(osp.join(unitest_dir, "descrpt_se_a_mask/nlist.npy")) + + print(np.allclose(descrpt.numpy(), descrpt_load)) + print(np.allclose(descrpt_deriv.numpy(), descrpt_deriv_load)) + print(np.allclose(rij.numpy(), rij_load)) + print(np.allclose(nlist.numpy(), nlist_load)) + + if __name__ == "__main__": test_neighbor_stat() @@ -266,3 +367,5 @@ def test_prod_virial_se_a(place="cpu"): test_prod_env_mat_a("cpu") test_prod_force_se_a("cpu") test_prod_virial_se_a("cpu") + test_prod_force_se_a_mask() + test_descrpt_se_a_mask() diff --git a/source/lib/paddle_src/paddle_descrpt_se_a_mask.cc 
b/source/lib/paddle_src/paddle_descrpt_se_a_mask.cc
new file mode 100644
index 0000000000..9a1bbb63be
--- /dev/null
+++ b/source/lib/paddle_src/paddle_descrpt_se_a_mask.cc
@@ -0,0 +1,342 @@
+#include <algorithm>
+
+#include "ComputeDescriptor.h"
+#include "errors.h"
+#include "fmt_nlist.h"
+#include "neighbor_list.h"
+#include "paddle/extension.h"
+
+typedef double boxtensor_t;
+typedef double compute_t;
+
+#define CHECK_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
+#define CHECK_INPUT_DIM(x, value) \
+  PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")
+#define CHECK_INPUT_READY(x) \
+  PD_CHECK(x.initialized(), #x " must be initialized before usage.")
+#define CHECK_INPUT_INT32(x) \
+  PD_CHECK(x.dtype() == paddle::DataType::INT32, #x " dtype should be INT32.")
+
+template <typename FPTYPE>
+struct NeighborInfo {
+  int type;
+  FPTYPE dist;
+  int index;
+  NeighborInfo() : type(0), dist(0), index(0) {}
+  NeighborInfo(int tt, FPTYPE dd, int ii) : type(tt), dist(dd), index(ii) {}
+  bool operator<(const NeighborInfo& b) const {
+    return (type < b.type ||
+            (type == b.type &&
+             (dist < b.dist || (dist == b.dist && index < b.index))));
+  }
+};
+
+compute_t max_distance = 10000.0;
+
+void buildAndSortNeighborList(int i_idx,
+                              const std::vector<compute_t> d_coord3,
+                              std::vector<int>& d_type,
+                              std::vector<int>& d_mask,
+                              std::vector<int>& sorted_nlist,
+                              int total_atom_num) {
+  // sorted_nlist.resize(total_atom_num);
+  std::vector<NeighborInfo<compute_t>> sel_nei;
+  for (int jj = 0; jj < total_atom_num; jj++) {
+    compute_t diff[3];
+    const int j_idx = jj;
+    for (int dd = 0; dd < 3; ++dd) {
+      diff[dd] = d_coord3[j_idx * 3 + dd] - d_coord3[i_idx * 3 + dd];
+    }
+    // Check if j_idx atom is a virtual particle or not.
+    compute_t rr = 0.0;
+    if (d_mask[j_idx] == 0 || j_idx == i_idx) {
+      rr = max_distance;
+    } else {
+      rr = sqrt(deepmd::dot3(diff, diff));
+    }
+    sel_nei.push_back(NeighborInfo<compute_t>(d_type[j_idx], rr, j_idx));
+  }
+  std::sort(sel_nei.begin(), sel_nei.end());
+  // Save the sorted atom index.
+  for (int jj = 0; jj < sel_nei.size(); jj++) {
+    int atom_idx = sel_nei[jj].index;
+    sorted_nlist[jj] = atom_idx;
+  }
+}
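A note on the ordering: `buildAndSortNeighborList` ranks every other atom by `(type, distance, index)` and assigns `max_distance` to virtual atoms (mask 0) and to the center atom itself, so they sink to the end of each type block. A minimal Python sketch of the same ranking (names are illustrative, not part of the patch):

    import math

    max_distance = 10000.0  # same sentinel as the C++ above

    def sorted_neighbors(i_idx, coords, types, mask):
        """Order all atoms by (type, distance, index); virtual atoms go last."""
        keys = []
        for j, (xyz, t, m) in enumerate(zip(coords, types, mask)):
            if m == 0 or j == i_idx:
                rr = max_distance  # virtual atom or self: pushed to the end
            else:
                rr = math.dist(xyz, coords[i_idx])
            keys.append((t, rr, j))
        keys.sort()
        return [j for _, _, j in keys]

    # Two real atoms of type 0 plus one virtual atom (mask 0) -> [1, 0, 2]
    print(sorted_neighbors(0, [(0, 0, 0), (1, 0, 0), (2, 0, 0)], [0, 0, 0], [1, 1, 0]))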
+
+template <typename data_t>
+void DescrptSeAMaskCPUKernel(int n_descrpt,
+                             int total_atom_num,
+                             int nsamples,
+                             const data_t* coord,
+                             const int* type,
+                             const int* mask_matrix,
+                             data_t* descrpt,
+                             data_t* descrpt_deriv,
+                             data_t* rij,
+                             int* nlist) {
+#pragma omp parallel for
+  for (int kk = 0; kk < nsamples; ++kk) {
+    // Iterate for each frame.
+    int nloc = total_atom_num;
+    int natoms = total_atom_num;
+
+    std::vector<compute_t> d_coord3(natoms * 3);
+    for (int ii = 0; ii < natoms; ++ii) {
+      for (int dd = 0; dd < 3; ++dd) {
+        d_coord3[ii * 3 + dd] = coord[kk * total_atom_num * 3 + ii * 3 + dd];
+      }
+    }
+
+    std::vector<int> d_type(natoms);
+    for (int ii = 0; ii < natoms; ++ii) {
+      d_type[ii] = type[kk * total_atom_num + ii];
+    }
+
+    std::vector<int> d_mask(natoms);
+    for (int ii = 0; ii < natoms; ++ii) {
+      d_mask[ii] = mask_matrix[kk * total_atom_num + ii];
+    }
+    std::vector<int> sorted_nlist(total_atom_num);
+
+    for (int ii = 0; ii < nloc; ii++) {
+      // Check whether this atom is a virtual atom. If it is, set the virtual
+      // atom's environment descriptor and its derivative to zero directly.
+      if (mask_matrix[kk * total_atom_num + ii] == 0) {
+        for (int jj = 0; jj < natoms * 4; ++jj) {
+          descrpt[kk * total_atom_num * total_atom_num * n_descrpt +
+                  ii * total_atom_num * 4 + jj] = 0.;
+        }
+        for (int jj = 0; jj < natoms * 4 * 3; ++jj) {
+          descrpt_deriv[kk * total_atom_num * total_atom_num * n_descrpt * 3 +
+                        ii * total_atom_num * 4 * 3 + jj] = 0.;
+        }
+        // Save the neighbor list relative coordinates with center atom ii.
+        for (int jj = 0; jj < natoms * 3; ++jj) {
+          rij[kk * total_atom_num * total_atom_num * 3 + ii * natoms * 3 + jj] =
+              0.;
+        }
+        // Save the neighbor atom indices.
+        for (int jj = 0; jj < natoms; jj++) {
+          nlist[kk * total_atom_num * total_atom_num + ii * natoms + jj] = -1;
+        }
+        continue;
+      }
+
+      // Build the neighbor list for atom ii.
+      std::fill(sorted_nlist.begin(), sorted_nlist.end(), -1);
+      buildAndSortNeighborList(ii, d_coord3, d_type, d_mask, sorted_nlist,
+                               total_atom_num);
+
+      // Set the center atom coordinates.
+      std::vector<compute_t> rloc(3);
+      for (int dd = 0; dd < 3; ++dd) {
+        rloc[dd] = coord[kk * total_atom_num * 3 + ii * 3 + dd];
+      }
+
+      // Compute the descriptor and its derivative for each atom.
+      std::vector<compute_t> descrpt_atom(natoms * 4);
+      std::vector<compute_t> descrpt_deriv_atom(natoms * 12);
+      std::vector<compute_t> rij_atom(natoms * 3);
+
+      std::fill(descrpt_deriv_atom.begin(), descrpt_deriv_atom.end(), 0.0);
+      std::fill(descrpt_atom.begin(), descrpt_atom.end(), 0.0);
+      std::fill(rij_atom.begin(), rij_atom.end(), 0.0);
+
+      // Compute the environment vector for each atom.
+      for (int jj = 0; jj < natoms; jj++) {
+        int j_idx = sorted_nlist[jj];
+
+        compute_t temp_rr;
+        compute_t temp_diff[3];
+        temp_rr = 0.;
+
+        // If ii == j_idx, or atom jj is a virtual atom, the descriptor and
+        // its derivative are set to zero.
+        if (ii == j_idx || mask_matrix[kk * total_atom_num + j_idx] == 0) {
+          // 1./rr, cos(theta), cos(phi), sin(phi)
+          descrpt_atom[jj * 4 + 0] = 0.;
+          descrpt_atom[jj * 4 + 1] = 0.;
+          descrpt_atom[jj * 4 + 2] = 0.;
+          descrpt_atom[jj * 4 + 3] = 0.;
+          // derivative of the component 1/r
+          descrpt_deriv_atom[jj * 12 + 0] = 0.;
+          descrpt_deriv_atom[jj * 12 + 1] = 0.;
+          descrpt_deriv_atom[jj * 12 + 2] = 0.;
+          // derivative of the component x/r2
+          descrpt_deriv_atom[jj * 12 + 3] = 0.;  // on x.
+          descrpt_deriv_atom[jj * 12 + 4] = 0.;  // on y.
+          descrpt_deriv_atom[jj * 12 + 5] = 0.;  // on z.
+          // derivative of the component y/r2
+          descrpt_deriv_atom[jj * 12 + 6] = 0.;  // on x.
+          descrpt_deriv_atom[jj * 12 + 7] = 0.;  // on y.
+          descrpt_deriv_atom[jj * 12 + 8] = 0.;  // on z.
+          // derivative of the component z/r2
+          descrpt_deriv_atom[jj * 12 + 9] = 0.;   // on x.
+          descrpt_deriv_atom[jj * 12 + 10] = 0.;  // on y.
+          descrpt_deriv_atom[jj * 12 + 11] = 0.;  // on z.
+          rij_atom[jj * 3 + 0] = 0.;
+          rij_atom[jj * 3 + 1] = 0.;
+          rij_atom[jj * 3 + 2] = 0.;
+          continue;
+        }
+
+        for (int dd = 0; dd < 3; dd++) {
+          temp_diff[dd] = d_coord3[j_idx * 3 + dd] - rloc[dd];
+          rij_atom[jj * 3 + dd] = temp_diff[dd];
+        }
+
+        temp_rr = deepmd::dot3(temp_diff, temp_diff);
+
+        compute_t x = temp_diff[0];
+        compute_t y = temp_diff[1];
+        compute_t z = temp_diff[2];
+
+        // r^2
+        compute_t nr2 = temp_rr;
+        // 1/r
+        compute_t inr = 1. / sqrt(nr2);
+        // r
+        compute_t nr = nr2 * inr;
+        // 1/r^2
+        compute_t inr2 = inr * inr;
+        // 1/r^4
+        compute_t inr4 = inr2 * inr2;
+        // 1/r^3
+        compute_t inr3 = inr * inr2;
+        // 1./rr, cos(theta), cos(phi), sin(phi)
+        descrpt_atom[jj * 4 + 0] = 1. / nr;
+        descrpt_atom[jj * 4 + 1] = x / nr2;
+        descrpt_atom[jj * 4 + 2] = y / nr2;
+        descrpt_atom[jj * 4 + 3] = z / nr2;
+        // derivative of the component 1/r
+        descrpt_deriv_atom[jj * 12 + 0] = x * inr3;
+        descrpt_deriv_atom[jj * 12 + 1] = y * inr3;
+        descrpt_deriv_atom[jj * 12 + 2] = z * inr3;
+        // derivative of the component x/r2
+        descrpt_deriv_atom[jj * 12 + 3] = 2. * x * x * inr4 - inr2;  // on x.
+        descrpt_deriv_atom[jj * 12 + 4] = 2. * x * y * inr4;         // on y.
+        descrpt_deriv_atom[jj * 12 + 5] = 2. * x * z * inr4;         // on z.
+        // derivative of the component y/r2
+        descrpt_deriv_atom[jj * 12 + 6] = 2. * y * x * inr4;         // on x.
+        descrpt_deriv_atom[jj * 12 + 7] = 2. * y * y * inr4 - inr2;  // on y.
+        descrpt_deriv_atom[jj * 12 + 8] = 2. * y * z * inr4;         // on z.
+        // derivative of the component z/r2
+        descrpt_deriv_atom[jj * 12 + 9] = 2. * z * x * inr4;          // on x.
+        descrpt_deriv_atom[jj * 12 + 10] = 2. * z * y * inr4;         // on y.
+        descrpt_deriv_atom[jj * 12 + 11] = 2. * z * z * inr4 - inr2;  // on z.
+      }
+
+      for (int jj = 0; jj < natoms * 4; ++jj) {
+        descrpt[kk * total_atom_num * total_atom_num * n_descrpt +
+                ii * total_atom_num * 4 + jj] = descrpt_atom[jj];
+      }
+      for (int jj = 0; jj < natoms * 4 * 3; ++jj) {
+        descrpt_deriv[kk * total_atom_num * total_atom_num * n_descrpt * 3 +
+                      ii * total_atom_num * 4 * 3 + jj] =
+            descrpt_deriv_atom[jj];
+      }
+      // Save the neighbor list relative coordinates with center atom ii.
+      for (int jj = 0; jj < natoms * 3; ++jj) {
+        rij[kk * total_atom_num * total_atom_num * 3 + ii * natoms * 3 + jj] =
+            rij_atom[jj];
+      }
+      // Save the neighbor atom indices.
+      for (int jj = 0; jj < natoms; ++jj) {
+        nlist[kk * total_atom_num * total_atom_num + ii * natoms + jj] =
+            sorted_nlist[jj];
+      }
+    }
+  }
+}
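For reference, each neighbor contributes the four components (1/r, x/r^2, y/r^2, z/r^2), and `descrpt_deriv` stores their 12 Cartesian derivatives taken with respect to the center-atom position (hence `x * inr3` for the 1/r component rather than -x/r^3). A small numpy check of that first derivative row, under the stated sign convention:

    import numpy as np

    def env_vec(d):
        x, y, z = d
        r2 = x * x + y * y + z * z
        r = np.sqrt(r2)
        return np.array([1.0 / r, x / r2, y / r2, z / r2])

    d = np.array([0.3, -1.2, 0.7])  # relative coordinate, illustrative
    x, y, z = d
    r = np.linalg.norm(d)
    inr3 = r**-3

    # First 3 entries of descrpt_deriv_atom above: x*inr3, y*inr3, z*inr3.
    analytic = np.array([x * inr3, y * inr3, z * inr3])

    eps = 1e-6
    numeric = np.array([
        (env_vec(d + eps * e)[0] - env_vec(d - eps * e)[0]) / (2 * eps)
        for e in np.eye(3)
    ])
    # Derivative w.r.t. the center atom flips the sign of the relative-coordinate
    # derivative, so analytic == -numeric here.
    print(np.allclose(analytic, -numeric, atol=1e-6))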
+
+std::vector<paddle::Tensor> DescrptSeAMaskCPU(
+    const paddle::Tensor& coord_tensor,
+    const paddle::Tensor& type_tensor,
+    const paddle::Tensor& mask_matrix_tensor,
+    const paddle::Tensor& box_tensor,
+    const paddle::Tensor& natoms_tensor,
+    const paddle::Tensor& mesh_tensor) {
+  CHECK_INPUT(coord_tensor);
+  CHECK_INPUT(type_tensor);
+  CHECK_INPUT(mask_matrix_tensor);
+  CHECK_INPUT(box_tensor);
+  CHECK_INPUT(natoms_tensor);
+  CHECK_INPUT(mesh_tensor);
+
+  CHECK_INPUT_INT32(type_tensor);
+  CHECK_INPUT_INT32(mask_matrix_tensor);
+  CHECK_INPUT_INT32(natoms_tensor);
+  CHECK_INPUT_INT32(mesh_tensor);
+  // set size of the sample
+  CHECK_INPUT_DIM(coord_tensor, 2);
+  CHECK_INPUT_DIM(type_tensor, 2);
+  CHECK_INPUT_DIM(mask_matrix_tensor, 2);
+
+  int nsamples = coord_tensor.shape()[0];
+
+  // check the sizes
+  PD_CHECK(nsamples == type_tensor.shape()[0],
+           "number of samples should match");
+  PD_CHECK(nsamples == mask_matrix_tensor.shape()[0],
+           "number of samples should match");
+
+  // Set n_descrpt for each atom: 1/rr, cos(theta), cos(phi), sin(phi)
+  int n_descrpt = 4;
+
+  // Calculate the total_atom_num
+  const int* natoms = natoms_tensor.data<int>();
+  int total_atom_num = natoms[1];
+  // check the sizes
+  PD_CHECK(total_atom_num * 3 == coord_tensor.shape()[1],
+           "number of atoms should match");
+  PD_CHECK(total_atom_num == mask_matrix_tensor.shape()[1],
+           "number of atoms should match");
+
+  // create output tensors
+  std::vector<int64_t> descrpt_shape{
+      nsamples, total_atom_num * total_atom_num * n_descrpt};
+  std::vector<int64_t> descrpt_deriv_shape{
+      nsamples, total_atom_num * total_atom_num * n_descrpt * 3};
+  std::vector<int64_t> rij_shape{nsamples, total_atom_num * total_atom_num * 3};
+  std::vector<int64_t> nlist_shape{nsamples, total_atom_num * total_atom_num};
+  paddle::Tensor descrpt_tensor =
+      paddle::empty(descrpt_shape, coord_tensor.dtype(), coord_tensor.place());
+  paddle::Tensor descrpt_deriv_tensor = paddle::empty(
+      descrpt_deriv_shape, coord_tensor.dtype(), coord_tensor.place());
+  paddle::Tensor rij_tensor =
+      paddle::empty(rij_shape, coord_tensor.dtype(), coord_tensor.place());
+  paddle::Tensor nlist_tensor =
+      paddle::empty(nlist_shape, type_tensor.dtype(), type_tensor.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      coord_tensor.type(), "descrpt_se_a_mask_kernel", ([&] {
+        DescrptSeAMaskCPUKernel<data_t>(
+            n_descrpt, total_atom_num, nsamples, coord_tensor.data<data_t>(),
+            type_tensor.data<int>(), mask_matrix_tensor.data<int>(),
+            descrpt_tensor.data<data_t>(), descrpt_deriv_tensor.data<data_t>(),
+            rij_tensor.data<data_t>(), nlist_tensor.data<int>());
+      }));
+  return {descrpt_tensor, descrpt_deriv_tensor, rij_tensor, nlist_tensor};
+}
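Because every atom can neighbor every other atom in this masked descriptor, the outputs are allocated with `total_atom_num * total_atom_num` neighbor slots rather than a fixed `sel` width. A quick sanity calculation of the flat widths, using the sizes noted in `custom_op_test.py` above:

    # Expected flat output widths per frame, mirroring the shapes above.
    nsamples, natoms, n_descrpt = 2, 147, 4  # sizes from the test data comments

    shapes = {
        "descrpt": natoms * natoms * n_descrpt,
        "descrpt_deriv": natoms * natoms * n_descrpt * 3,
        "rij": natoms * natoms * 3,
        "nlist": natoms * natoms,
    }
    for name, width in shapes.items():
        print(name, (nsamples, width))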
+
+std::vector<paddle::Tensor> DescrptSeAMask(
+    const paddle::Tensor& coord_tensor,
+    const paddle::Tensor& type_tensor,
+    const paddle::Tensor& mask_matrix_tensor,
+    const paddle::Tensor& box_tensor,
+    const paddle::Tensor& natoms_tensor,
+    const paddle::Tensor& mesh_tensor) {
+  if (coord_tensor.is_cpu()) {
+    return DescrptSeAMaskCPU(coord_tensor, type_tensor, mask_matrix_tensor,
+                             box_tensor, natoms_tensor, mesh_tensor);
+  } else {
+    PD_THROW("DescrptSeAMask only supports the CPU device.");
+  }
+}
+
+PD_BUILD_OP(descrpt_se_a_mask)
+    .Inputs({"coord", "type", "mask", "box", "natoms", "mesh"})
+    .Outputs({"descrpt", "descrpt_deriv", "rij", "nlist"})
+    .SetKernelFn(PD_KERNEL(DescrptSeAMask));
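For completeness, a hedged sketch of exercising this op without the `custom_op_install.py` route, using Paddle's JIT extension loader. The module name here is hypothetical and the source list mirrors the one added to `custom_op_install.py` above:

    from paddle.utils.cpp_extension import load

    lib = load(
        name="paddle_deepmd_lib_jit",  # hypothetical module name
        sources=[
            "paddle_descrpt_se_a_mask.cc",
            "paddle_prod_force_se_a_mask.cc",
            "paddle_prod_force_se_a_mask_grad.cc",
        ],
    )
    # Each PD_BUILD_OP name becomes a Python callable, e.g.:
    # descrpt, descrpt_deriv, rij, nlist = lib.descrpt_se_a_mask(
    #     coord, atype, mask, box, natoms, mesh)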
diff --git a/source/lib/paddle_src/paddle_prod_force_se_a_mask.cc b/source/lib/paddle_src/paddle_prod_force_se_a_mask.cc
new file mode 100644
index 0000000000..66c6665f72
--- /dev/null
+++ b/source/lib/paddle_src/paddle_prod_force_se_a_mask.cc
@@ -0,0 +1,157 @@
+#include "paddle/extension.h"
+
+#define CHECK_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
+#define CHECK_INPUT_DIM(x, value) \
+  PD_CHECK(x.shape().size() == value, #x "'s dim should be " #value ".")
+#define CHECK_INPUT_READY(x) \
+  PD_CHECK(x.initialized(), #x " must be initialized before usage.")
+
+template <typename data_t>
+void ProdForceSeAMaskOpForwardCPUKernel(int nframes,
+                                        int total_atom_num,
+                                        const data_t* net_deriv,
+                                        const data_t* in_deriv,
+                                        const int* mask,
+                                        const int* nlist,
+                                        data_t* force) {
+  int nloc = total_atom_num;
+  int nall = total_atom_num;
+  int ndescrpt = nall * 4;
+
+#pragma omp parallel for
+  for (int kk = 0; kk < nframes; ++kk) {
+    int force_iter = kk * nall * 3;
+    int net_iter = kk * nall * ndescrpt;
+    int in_iter = kk * nall * ndescrpt * 3;
+    int mask_iter = kk * nall;
+    int nlist_iter = kk * nall * nall;
+
+    for (int ii = 0; ii < nall; ii++) {
+      int i_idx = ii;
+      force[force_iter + i_idx * 3 + 0] = 0.0;
+      force[force_iter + i_idx * 3 + 1] = 0.0;
+      force[force_iter + i_idx * 3 + 2] = 0.0;
+    }
+
+    for (int ii = 0; ii < nall; ii++) {
+      int i_idx = ii;
+      // Check if the atom ii is a virtual particle or not.
+      if (mask[mask_iter + i_idx] == 0) {
+        continue;
+      }
+      // Derivative with respect to the center atom.
+      for (int aa = 0; aa < nall * 4; ++aa) {
+        force[force_iter + i_idx * 3 + 0] -=
+            net_deriv[net_iter + i_idx * ndescrpt + aa] *
+            in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 0];
+        force[force_iter + i_idx * 3 + 1] -=
+            net_deriv[net_iter + i_idx * ndescrpt + aa] *
+            in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 1];
+        force[force_iter + i_idx * 3 + 2] -=
+            net_deriv[net_iter + i_idx * ndescrpt + aa] *
+            in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 2];
+      }
+      // Derivative with respect to the other atoms.
+      for (int jj = 0; jj < nall; jj++) {
+        // Get the neighbor index from nlist tensor.
+        int j_idx = nlist[nlist_iter + i_idx * nall + jj];
+
+        if (j_idx == i_idx) {
+          continue;
+        }
+        int aa_start, aa_end;
+        aa_start = jj * 4;
+        aa_end = jj * 4 + 4;
+        for (int aa = aa_start; aa < aa_end; aa++) {
+          force[force_iter + j_idx * 3 + 0] +=
+              net_deriv[net_iter + i_idx * ndescrpt + aa] *
+              in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 0];
+          force[force_iter + j_idx * 3 + 1] +=
+              net_deriv[net_iter + i_idx * ndescrpt + aa] *
+              in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 1];
+          force[force_iter + j_idx * 3 + 2] +=
+              net_deriv[net_iter + i_idx * ndescrpt + aa] *
+              in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + 2];
+        }
+      }
+    }
+  }
+}
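The kernel above is the chain rule made explicit: the center atom ii accumulates the negative product of its full `net_deriv` row with `in_deriv`, and each neighbor j receives the positive product restricted to its own four descriptor columns. A numpy sketch of the same contraction for one frame (shapes illustrative):

    import numpy as np

    nall = 5
    ndescrpt = nall * 4
    rng = np.random.default_rng(0)

    net_deriv = rng.normal(size=(nall, ndescrpt))
    in_deriv = rng.normal(size=(nall, ndescrpt, 3))
    mask = np.ones(nall, dtype=int)
    nlist = np.tile(np.arange(nall), (nall, 1))

    force = np.zeros((nall, 3))
    for i in range(nall):
        if mask[i] == 0:
            continue
        # center-atom term (note the minus sign, as in the kernel)
        force[i] -= np.einsum("a,ad->d", net_deriv[i], in_deriv[i])
        # neighbor terms: only columns belonging to neighbor slot jj
        for jj, j in enumerate(nlist[i]):
            if j == i:
                continue
            cols = slice(jj * 4, jj * 4 + 4)
            force[j] += np.einsum("a,ad->d", net_deriv[i, cols], in_deriv[i, cols])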
+
+std::vector<paddle::Tensor> ProdForceSeAMaskForward(
+    const paddle::Tensor& net_deriv_tensor,
+    const paddle::Tensor& in_deriv_tensor,
+    const paddle::Tensor& mask_tensor,
+    const paddle::Tensor& nlist_tensor,
+    int total_atom_num) {
+  CHECK_INPUT(net_deriv_tensor);
+  CHECK_INPUT(in_deriv_tensor);
+  CHECK_INPUT(mask_tensor);
+  CHECK_INPUT(nlist_tensor);
+
+  CHECK_INPUT_DIM(net_deriv_tensor, 2);
+  CHECK_INPUT_DIM(in_deriv_tensor, 2);
+  CHECK_INPUT_DIM(mask_tensor, 2);
+  CHECK_INPUT_DIM(nlist_tensor, 2);
+
+  PD_CHECK(total_atom_num >= 3,
+           "Number of atoms should be larger than (or equal to) 3");
+
+  int nframes = net_deriv_tensor.shape()[0];
+  int nloc = total_atom_num;
+  int nall = total_atom_num;
+  int ndescrpt = nall * 4;
+  int nnei = nloc > 0 ? nlist_tensor.shape()[1] / nloc : 0;
+
+  PD_CHECK(nframes == in_deriv_tensor.shape()[0],
+           "Number of samples should match");
+  PD_CHECK(nframes == nlist_tensor.shape()[0],
+           "Number of samples should match");
+  PD_CHECK(nloc * ndescrpt * 3 == in_deriv_tensor.shape()[1],
+           "Number of descriptors should match");
+
+  // Create output tensor
+  std::vector<int64_t> force_shape{nframes, 3 * nall};
+  paddle::Tensor force_tensor = paddle::empty(
+      force_shape, net_deriv_tensor.dtype(), net_deriv_tensor.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      net_deriv_tensor.type(), "prod_force_se_a_mask_cpu_forward_kernel", ([&] {
+        ProdForceSeAMaskOpForwardCPUKernel<data_t>(
+            nframes, total_atom_num, net_deriv_tensor.data<data_t>(),
+            in_deriv_tensor.data<data_t>(), mask_tensor.data<int>(),
+            nlist_tensor.data<int>(), force_tensor.data<data_t>());
+      }));
+
+  return {force_tensor};
+}
+
+std::vector<std::vector<int64_t>> ProdForceSeAMaskOpInferShape(
+    std::vector<int64_t> net_deriv_shape,
+    std::vector<int64_t> in_deriv_shape,
+    std::vector<int64_t> mask_shape,
+    std::vector<int64_t> nlist_shape,
+    const int& total_atom_num) {
+  int64_t nall = total_atom_num;
+  int64_t nframes = net_deriv_shape[0];
+
+  std::vector<int64_t> force_shape = {nframes, 3 * nall};
+
+  return {force_shape};
+}
+
+std::vector<paddle::DataType> ProdForceSeAMaskOpInferDtype(
+    paddle::DataType net_deriv_dtype,
+    paddle::DataType in_deriv_dtype,
+    paddle::DataType mask_dtype,
+    paddle::DataType nlist_dtype) {
+  return {net_deriv_dtype};
+}
+
+PD_BUILD_OP(prod_force_se_a_mask)
+    .Inputs({"net_deriv", "in_deriv", "mask", "nlist"})
+    .Attrs({"total_atom_num: int"})
+    .Outputs({"force"})
+    .SetKernelFn(PD_KERNEL(ProdForceSeAMaskForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(ProdForceSeAMaskOpInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(ProdForceSeAMaskOpInferDtype));
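Since the gradient registered in the companion `_grad.cc` file below is hand-written, it is worth checking it against finite differences of the forward op. A hedged sketch, assuming the compiled `paddle_deepmd_lib` module used in `custom_op_test.py` and small CPU tensors:

    import numpy as np
    import paddle
    import paddle_deepmd_lib  # the module imported in custom_op_test.py

    def fd_check(net_deriv, descrpt_deriv, mask, nlist, total_atom_num, eps=1e-5):
        """Compare the registered grad op to central finite differences."""
        x = paddle.to_tensor(net_deriv, stop_gradient=False)
        force = paddle_deepmd_lib.prod_force_se_a_mask(
            x, descrpt_deriv, mask, nlist, total_atom_num)
        force.sum().backward()
        analytic = x.grad.numpy().reshape(-1)

        nd = net_deriv.copy().reshape(-1)
        numeric = np.zeros_like(nd)
        for k in range(nd.size):
            for step, sign in ((eps, 1.0), (-eps, -1.0)):
                nd[k] += step
                f = paddle_deepmd_lib.prod_force_se_a_mask(
                    paddle.to_tensor(nd.reshape(net_deriv.shape)),
                    descrpt_deriv, mask, nlist, total_atom_num)
                numeric[k] += sign * float(f.sum()) / (2 * eps)
                nd[k] -= step
        return np.allclose(analytic, numeric, atol=1e-4)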
+      if (j_idx == i_idx || j_idx < 0) {
+          continue;
+        }
+        /*
+        if (j_idx > nloc)
+          j_idx = j_idx % nloc;
+        if (j_idx < 0)
+          continue;
+        */
+        int aa_start, aa_end;
+        aa_start = jj * 4;
+        aa_end = jj * 4 + 4;
+        // make_descript_range (aa_start, aa_end, jj);
+        for (int aa = aa_start; aa < aa_end; ++aa) {
+          for (int dd = 0; dd < 3; ++dd) {
+            grad_net[grad_net_iter + i_idx * ndescrpt + aa] +=
+                grad[grad_iter + j_idx * 3 + dd] *
+                in_deriv[in_iter + i_idx * ndescrpt * 3 + aa * 3 + dd];
+          }
+        }
+      }
+    }
+  }
+}
+
+std::vector<paddle::Tensor> ProdForceSeAMaskOpCPUBackward(
+    const paddle::Tensor& grad_tensor,
+    const paddle::Tensor& net_deriv_tensor,
+    const paddle::Tensor& in_deriv_tensor,
+    const paddle::Tensor& mask_tensor,
+    const paddle::Tensor& nlist_tensor,
+    int total_atom_num) {
+  CHECK_INPUT(grad_tensor);
+  CHECK_INPUT(net_deriv_tensor);
+  CHECK_INPUT(in_deriv_tensor);
+  CHECK_INPUT(mask_tensor);
+  CHECK_INPUT(nlist_tensor);
+
+  CHECK_INPUT_DIM(grad_tensor, 2);
+  CHECK_INPUT_DIM(net_deriv_tensor, 2);
+  CHECK_INPUT_DIM(in_deriv_tensor, 2);
+  CHECK_INPUT_DIM(mask_tensor, 2);
+  CHECK_INPUT_DIM(nlist_tensor, 2);
+
+  PD_CHECK(total_atom_num >= 3,
+           "Number of atoms should be larger than (or equal to) 3");
+
+  int nframes = net_deriv_tensor.shape()[0];
+  int nloc = total_atom_num;
+  int ndescrpt = nloc > 0 ? net_deriv_tensor.shape()[1] / nloc : 0;
+  int nnei = total_atom_num;
+
+  PD_CHECK(nframes == grad_tensor.shape()[0], "Number of frames should match");
+  PD_CHECK(nframes == in_deriv_tensor.shape()[0],
+           "Number of frames should match");
+  PD_CHECK(nframes == nlist_tensor.shape()[0], "Number of frames should match");
+  PD_CHECK(nframes == mask_tensor.shape()[0], "Number of frames should match");
+
+  PD_CHECK(nloc * 3 == grad_tensor.shape()[1],
+           "input grad shape should be 3 x natoms");
+  PD_CHECK(nloc * ndescrpt * 3 == in_deriv_tensor.shape()[1],
+           "Number of descriptors should match");
+
+  // Create an output tensor
+  std::vector<int64_t> grad_net_shape{nframes, nloc * ndescrpt};
+  paddle::Tensor grad_net_tensor =
+      paddle::empty(grad_net_shape, grad_tensor.dtype(), grad_tensor.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      grad_tensor.type(), "prod_force_se_a_mask_cpu_backward_kernel", ([&] {
+        ProdForceSeAMaskOpCPUBackwardKernel<data_t>(
+            nloc, nframes, ndescrpt, nnei, grad_tensor.data<data_t>(),
+            net_deriv_tensor.data<data_t>(), in_deriv_tensor.data<data_t>(),
+            mask_tensor.data<int>(), nlist_tensor.data<int>(),
+            grad_net_tensor.data<data_t>());
+      }));
+
+  return {grad_net_tensor};
+}
+
+std::vector<paddle::Tensor> ProdForceSeAMaskBackward(
+    const paddle::Tensor& grad_tensor,
+    const paddle::Tensor& net_deriv_tensor,
+    const paddle::Tensor& in_deriv_tensor,
+    const paddle::Tensor& mask_tensor,
+    const paddle::Tensor& nlist_tensor,
+    int total_atom_num) {
+  return ProdForceSeAMaskOpCPUBackward(grad_tensor, net_deriv_tensor,
+                                       in_deriv_tensor, mask_tensor,
+                                       nlist_tensor, total_atom_num);
+}
+
+PD_BUILD_GRAD_OP(prod_force_se_a_mask)
+    .Inputs({paddle::Grad("force"), "net_deriv", "in_deriv", "mask", "nlist"})
+    .Attrs({"total_atom_num: int"})
+    .Outputs({paddle::Grad("net_deriv")})
+    .SetKernelFn(PD_KERNEL(ProdForceSeAMaskBackward));
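Finally, a short sketch of how the new pieces fit together for a smoke test. It assumes the ops were compiled first (for example via `python custom_op_install.py`) and that the reference `.npy` fixtures referenced in `custom_op_test.py` are present:

    # Hedged end-to-end smoke test, mirroring the entry points added above.
    from custom_op_test import test_descrpt_se_a_mask, test_prod_force_se_a_mask

    test_prod_force_se_a_mask()  # compares force and net_deriv grad to saved .npy
    test_descrpt_se_a_mask()     # compares descrpt/deriv/rij/nlist to saved .npy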