From 2ee2c6a518e1e24830a750fbaa1b1fab2cb399e4 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 28 Feb 2024 07:45:23 +0000 Subject: [PATCH] add tensor-dipole example --- deepmd/fit/dipole.py | 209 +++++++++++++-------- deepmd/infer/deep_eval.py | 17 +- deepmd/infer/deep_tensor.py | 364 ++++++++++++++++++++++++++++-------- deepmd/loss/tensor.py | 106 ++++++++--- deepmd/model/tensor.py | 92 ++++----- deepmd/train/trainer.py | 8 +- 6 files changed, 562 insertions(+), 234 deletions(-) diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index c2cc1e09fb..be949a046e 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -4,29 +4,31 @@ ) import numpy as np +from paddle import ( + nn, +) from deepmd.common import ( - cast_precision, get_activation_func, get_precision, ) from deepmd.env import ( + paddle, tf, ) -from deepmd.fit.fitting import ( - Fitting, -) + +# from deepmd.infer import DeepPotential from deepmd.utils.graph import ( get_fitting_net_variables_from_graph_def, ) +from deepmd.utils.network import OneLayer as OneLayer_deepmd from deepmd.utils.network import ( - one_layer, one_layer_rand_seed_shift, ) -@Fitting.register("dipole") -class DipoleFittingSeA(Fitting): +# @Fitting.register("dipole") +class DipoleFittingSeA(nn.Layer): r"""Fit the atomic dipole with descriptor se_a. Parameters @@ -52,7 +54,7 @@ class DipoleFittingSeA(Fitting): def __init__( self, - descrpt: tf.Tensor, + descrpt: paddle.Tensor, neuron: List[int] = [120, 120, 120], resnet_dt: bool = True, sel_type: Optional[List[int]] = None, @@ -61,6 +63,7 @@ def __init__( precision: str = "default", uniform_seed: bool = False, ) -> None: + super().__init__(name_scope="DipoleFittingSeA") """Constructor.""" self.ntypes = descrpt.get_ntypes() self.dim_descrpt = descrpt.get_dim_out() @@ -74,6 +77,7 @@ def __init__( ) self.seed = seed self.uniform_seed = uniform_seed + self.ntypes_spin = 0 self.seed_shift = one_layer_rand_seed_shift() self.fitting_activation_fn = get_activation_func(activation_function) self.fitting_precision = get_precision(precision) @@ -83,6 +87,54 @@ def __init__( self.fitting_net_variables = None self.mixed_prec = None + type_suffix = "" + suffix = "" + self.one_layers = nn.LayerList() + self.final_layers = nn.LayerList() + ntypes_atom = self.ntypes - self.ntypes_spin + for type_i in range(0, ntypes_atom): + type_i_layers = nn.LayerList() + for ii in range(0, len(self.n_neuron)): + layer_suffix = "layer_" + str(ii) + type_suffix + suffix + + if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]: + type_i_layers.append( + OneLayer_deepmd( + self.n_neuron[ii - 1], + self.n_neuron[ii], + activation_fn=self.fitting_activation_fn, + precision=self.fitting_precision, + name=layer_suffix, + seed=self.seed, + use_timestep=self.resnet_dt, + ) + ) + else: + type_i_layers.append( + OneLayer_deepmd( + self.dim_descrpt, + self.n_neuron[ii], + activation_fn=self.fitting_activation_fn, + precision=self.fitting_precision, + name=layer_suffix, + seed=self.seed, + ) + ) + if (not self.uniform_seed) and (self.seed is not None): + self.seed += self.seed_shift + + self.one_layers.append(type_i_layers) + self.final_layers.append( + OneLayer_deepmd( + self.n_neuron[-1], + self.dim_rot_mat_1, + activation_fn=None, + precision=self.fitting_precision, + name=layer_suffix, + seed=self.seed, + ) + ) + def get_sel_type(self) -> int: """Get selected type.""" return self.sel_type @@ -91,79 +143,66 @@ def get_out_size(self) -> int: """Get the output size. 
Should be 3.""" return 3 - def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None): + def _build_lower( + self, + start_index, + natoms, + inputs, + rot_mat, + suffix="", + reuse=None, + type_i=None, + ): # cut-out inputs - inputs_i = tf.slice(inputs, [0, start_index, 0], [-1, natoms, -1]) - inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt]) - rot_mat_i = tf.slice(rot_mat, [0, start_index, 0], [-1, natoms, -1]) - rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3]) + inputs_i = paddle.slice( + inputs, + [0, 1, 2], + [0, start_index, 0], + [inputs.shape[0], start_index + natoms, inputs.shape[2]], + ) + inputs_i = paddle.reshape(inputs_i, [-1, self.dim_descrpt]) + rot_mat_i = paddle.slice( + rot_mat, + [0, 1, 2], + [0, start_index, 0], + [rot_mat.shape[0], start_index + natoms, rot_mat.shape[2]], + ) + rot_mat_i = paddle.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3]) layer = inputs_i for ii in range(0, len(self.n_neuron)): if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]: - layer += one_layer( - layer, - self.n_neuron[ii], - name="layer_" + str(ii) + suffix, - reuse=reuse, - seed=self.seed, - use_timestep=self.resnet_dt, - activation_fn=self.fitting_activation_fn, - precision=self.fitting_precision, - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - ) + layer += self.one_layers[type_i][ii](layer) else: - layer = one_layer( - layer, - self.n_neuron[ii], - name="layer_" + str(ii) + suffix, - reuse=reuse, - seed=self.seed, - activation_fn=self.fitting_activation_fn, - precision=self.fitting_precision, - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - ) + layer = self.one_layers[type_i][ii](layer) + if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift # (nframes x natoms) x naxis - final_layer = one_layer( + final_layer = self.final_layers[type_i]( layer, - self.dim_rot_mat_1, - activation_fn=None, - name="final_layer" + suffix, - reuse=reuse, - seed=self.seed, - precision=self.fitting_precision, - uniform_seed=self.uniform_seed, - initial_variables=self.fitting_net_variables, - mixed_prec=self.mixed_prec, - final_layer=True, ) + if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift # (nframes x natoms) x 1 * naxis - final_layer = tf.reshape( - final_layer, [tf.shape(inputs)[0] * natoms, 1, self.dim_rot_mat_1] + final_layer = paddle.reshape( + final_layer, [paddle.shape(inputs)[0] * natoms, 1, self.dim_rot_mat_1] ) # (nframes x natoms) x 1 x 3(coord) - final_layer = tf.matmul(final_layer, rot_mat_i) + final_layer = paddle.matmul(final_layer, rot_mat_i) # nframes x natoms x 3 - final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3]) + final_layer = paddle.reshape(final_layer, [paddle.shape(inputs)[0], natoms, 3]) return final_layer - @cast_precision - def build( + def forward( self, - input_d: tf.Tensor, - rot_mat: tf.Tensor, - natoms: tf.Tensor, + input_d: paddle.Tensor, + rot_mat: paddle.Tensor, + natoms: paddle.Tensor, input_dict: Optional[dict] = None, reuse: Optional[bool] = None, suffix: str = "", - ) -> tf.Tensor: + ) -> paddle.Tensor: """Build the computational graph for fitting net. 
Parameters @@ -195,22 +234,25 @@ def build( atype = input_dict.get("atype", None) nframes = input_dict.get("nframes") start_index = 0 - inputs = tf.reshape(input_d, [-1, natoms[0], self.dim_descrpt]) - rot_mat = tf.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat]) + inputs = paddle.reshape(input_d, [-1, natoms[0], self.dim_descrpt]) + rot_mat = paddle.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat]) if type_embedding is not None: - nloc_mask = tf.reshape( - tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1] + nloc_mask = paddle.reshape( + paddle.tile( + paddle.repeat_interleave(self.sel_mask, natoms[2:]), [nframes] + ), + [nframes, -1], ) - atype_nall = tf.reshape(atype, [-1, natoms[1]]) + atype_nall = paddle.reshape(atype, [-1, natoms[1]]) # (nframes x nloc_masked) - self.atype_nloc_masked = tf.reshape( - tf.slice(atype_nall, [0, 0], [-1, natoms[0]])[nloc_mask], [-1] + self.atype_nloc_masked = paddle.reshape( + paddle.slice(atype_nall, [0, 0], [-1, natoms[0]])[nloc_mask], [-1] ) ## lammps will make error - self.nloc_masked = tf.shape( - tf.reshape(self.atype_nloc_masked, [nframes, -1]) + self.nloc_masked = paddle.shape( + paddle.reshape(self.atype_nloc_masked, [nframes, -1]) )[1] - atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked) + atype_embed = nn.embedding_lookup(type_embedding, self.atype_nloc_masked) else: atype_embed = None @@ -230,40 +272,43 @@ def build( rot_mat, suffix="_type_" + str(type_i) + suffix, reuse=reuse, + type_i=type_i, ) start_index += natoms[2 + type_i] # concat the results outs_list.append(final_layer) count += 1 - outs = tf.concat(outs_list, axis=1) + outs = paddle.concat(outs_list, axis=1) else: - inputs = tf.reshape( - tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask], + inputs = paddle.reshape( + paddle.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[ + nloc_mask + ], [-1, self.dim_descrpt], ) - rot_mat = tf.reshape( - tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[ + rot_mat = paddle.reshape( + paddle.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[ nloc_mask ], [-1, self.dim_rot_mat_1, 3], ) - atype_embed = tf.cast(atype_embed, self.fitting_precision) + atype_embed = paddle.cast(atype_embed, self.fitting_precision) type_shape = atype_embed.get_shape().as_list() - inputs = tf.concat([inputs, atype_embed], axis=1) + inputs = paddle.concat([inputs, atype_embed], axis=1) self.dim_descrpt = self.dim_descrpt + type_shape[1] - inputs = tf.reshape(inputs, [nframes, self.nloc_masked, self.dim_descrpt]) - rot_mat = tf.reshape( + inputs = paddle.reshape( + inputs, [nframes, self.nloc_masked, self.dim_descrpt] + ) + rot_mat = paddle.reshape( rot_mat, [nframes, self.nloc_masked, self.dim_rot_mat_1 * 3] ) final_layer = self._build_lower( 0, self.nloc_masked, inputs, rot_mat, suffix=suffix, reuse=reuse ) # nframes x natoms x 3 - outs = tf.reshape(final_layer, [nframes, self.nloc_masked, 3]) + outs = paddle.reshape(final_layer, [nframes, self.nloc_masked, 3]) - tf.summary.histogram("fitting_net_output", outs) - return tf.reshape(outs, [-1]) - # return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3]) + return paddle.reshape(outs, [-1]) def init_variables( self, diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 77c81dc202..8c9dd588ba 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -1,3 +1,4 @@ +import os from functools import ( lru_cache, ) @@ -22,6 +23,7 @@ tf, ) from deepmd.fit import ( + dipole, ener, 
) from deepmd.model import ( @@ -92,7 +94,9 @@ def __init__( default_tf_graph: bool = False, auto_batch_size: Union[bool, int, AutoBatchSize] = False, ): - jdata = j_loader("input.json") + jdata = j_loader( + "input.json" if os.path.exists("input.json") else "dipole_input.json" + ) remove_comment_in_json(jdata) model_param = j_must_have(jdata, "model") self.multi_task_mode = "fitting_net_dict" in model_param @@ -147,7 +151,12 @@ def __init__( if fitting_type == "ener": fitting_param["spin"] = spin fitting_param.pop("type", None) - fitting = ener.EnerFitting(**fitting_param) + fitting = ener.EnerFitting(**fitting_param) + elif fitting_type == "dipole": + fitting_param.pop("type", None) + fitting = dipole.DipoleFittingSeA(**fitting_param) + else: + raise NotImplementedError() else: self.fitting_dict = {} self.fitting_type_dict = {} @@ -359,7 +368,7 @@ def __init__( @property @lru_cache(maxsize=None) def model_type(self) -> str: - return "ener" + return self.model.model_type """Get type of model. :type:str @@ -418,7 +427,7 @@ def _graph_compatable(self) -> bool: def _get_value( self, tensor_name: str, attr_name: Optional[str] = None - ) -> tf.Tensor: + ) -> paddle.Tensor: """Get TF graph tensor and assign it to class namespace. Parameters diff --git a/deepmd/infer/deep_tensor.py b/deepmd/infer/deep_tensor.py index ca62b97385..065f6fd2a1 100644 --- a/deepmd/infer/deep_tensor.py +++ b/deepmd/infer/deep_tensor.py @@ -1,8 +1,10 @@ from typing import ( TYPE_CHECKING, + Callable, List, Optional, Tuple, + Union, ) import numpy as np @@ -10,6 +12,9 @@ from deepmd.common import ( make_default_mesh, ) +from deepmd.env import ( + paddle, +) from deepmd.infer.deep_eval import ( DeepEval, ) @@ -70,7 +75,7 @@ def __init__( # now load tensors to object attributes for attr_name, tensor_name in self.tensors.items(): - self._get_tensor(tensor_name, attr_name) + self._get_value(tensor_name, attr_name) # load optional tensors if possible optional_tensors = { @@ -82,18 +87,20 @@ def __init__( try: # first make sure these tensor all exists (but do not modify self attr) for attr_name, tensor_name in optional_tensors.items(): - self._get_tensor(tensor_name) + self._get_value(tensor_name) # then put those into self.attrs for attr_name, tensor_name in optional_tensors.items(): - self._get_tensor(tensor_name, attr_name) + self._get_value(tensor_name, attr_name) except KeyError: self._support_gfv = False else: self.tensors.update(optional_tensors) self._support_gfv = True - self._run_default_sess() - self.tmap = self.tmap.decode("UTF-8").split() + # self._run_default_sess() + # self.tmap = self.tmap.decode("UTF-8").split() + self.ntypes = int(self.model.descrpt.buffer_ntypes) + self.tselt = self.model.fitting.sel_type def _run_default_sess(self): [self.ntypes, self.rcut, self.tmap, self.tselt, self.output_dim] = run_sess( @@ -107,6 +114,247 @@ def _run_default_sess(self): ], ) + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size. 
+ + Parameters + ---------- + inner_func : Callable + the method to be wrapped + numb_test : int + number of tests + natoms : int + number of atoms + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) + + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: Union[List[int], np.ndarray], + mixed_type: bool = False, + ) -> Tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _prepare_feed_dict( + self, + coords, + cells, + atom_types, + fparam=None, + aparam=None, + atomic=False, + efield=None, + mixed_type=False, + ): + # standarize the shape of inputs + natoms, nframes = self._get_natoms_and_nframes( + coords, atom_types, mixed_type=mixed_type + ) + if mixed_type: + atom_types = np.array(atom_types, dtype=int).reshape([-1, natoms]) + else: + atom_types = np.array(atom_types, dtype=int).reshape([-1]) + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + if cells is None: + pbc = False + # make cells to work around the requirement of pbc + cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9]) + else: + pbc = True + cells = np.array(cells).reshape([nframes, 9]) + + # if self.has_fparam: + # assert fparam is not None + # fparam = np.array(fparam) + # if self.has_aparam: + # assert aparam is not None + # aparam = np.array(aparam) + # if self.has_efield: + # assert ( + # efield is not None + # ), "you are using a model with external field, parameter efield should be provided" + # efield = np.array(efield) + + # reshape the inputs + # if self.has_fparam: + # fdim = self.get_dim_fparam() + # if fparam.size == nframes * fdim: + # fparam = np.reshape(fparam, [nframes, fdim]) + # elif fparam.size == fdim: + # fparam = np.tile(fparam.reshape([-1]), [nframes, 1]) + # else: + # raise RuntimeError( + # "got wrong size of frame param, should be either %d x %d or %d" + # % (nframes, fdim, fdim) + # ) + # if self.has_aparam: + # fdim = self.get_dim_aparam() + # if aparam.size == nframes * natoms * fdim: + # aparam = np.reshape(aparam, [nframes, natoms * fdim]) + # elif aparam.size == natoms * fdim: + # aparam = np.tile(aparam.reshape([-1]), [nframes, 1]) + # elif aparam.size == fdim: + # aparam = np.tile(aparam.reshape([-1]), [nframes, natoms]) + # else: + # raise RuntimeError( + # "got wrong size of frame param, should be either %d x %d x %d or %d x %d or %d" + # % (nframes, natoms, fdim, natoms, fdim, fdim) + # ) + + # sort inputs + # coords, atom_types, imap = self.sort_input( + # coords, atom_types, mixed_type=mixed_type + # ) + # if self.has_efield: + # efield = np.reshape(efield, [nframes, natoms, 3]) + # efield = efield[:, imap, :] + # efield = np.reshape(efield, [nframes, natoms * 3]) + + # make natoms_vec and default_mesh + natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type) + assert natoms_vec[0] == natoms + + # evaluate + # feed_dict_test = {} + # feed_dict_test[self.t_natoms] = natoms_vec + # if mixed_type: + # feed_dict_test[self.t_type] = atom_types.reshape([-1]) + # else: + # feed_dict_test[self.t_type] = np.tile(atom_types, [nframes, 1]).reshape( + # [-1] + # ) + # feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) + + # if len(self.t_box.shape) == 1: + 
# feed_dict_test[self.t_box] = np.reshape(cells, [-1]) + # elif len(self.t_box.shape) == 2: + # feed_dict_test[self.t_box] = cells + # else: + # raise RuntimeError + # if self.has_efield: + # feed_dict_test[self.t_efield] = np.reshape(efield, [-1]) + # if pbc: + # feed_dict_test[self.t_mesh] = make_default_mesh(cells) + # else: + # feed_dict_test[self.t_mesh] = np.array([], dtype=np.int32) + # if self.has_fparam: + # feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) + # if self.has_aparam: + # feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) + return None, None, natoms_vec + + def _eval_inner( + self, + coords, + cells, + atom_types, + fparam=None, + aparam=None, + atomic=False, + efield=None, + mixed_type=False, + ): + natoms, nframes = self._get_natoms_and_nframes( + coords, atom_types, mixed_type=mixed_type + ) + feed_dict_test, imap, natoms_vec = self._prepare_feed_dict( + coords, cells, atom_types, fparam, aparam, efield, mixed_type=mixed_type + ) + if cells is None: + pbc = False + cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9]) + else: + pbc = True + cells = np.array(cells).reshape([nframes, 9]) + eval_inputs = {} + eval_inputs["coord"] = paddle.to_tensor( + np.reshape(coords, [-1]), dtype="float64" + ) + eval_inputs["type"] = paddle.to_tensor( + np.tile(atom_types, [nframes, 1]).reshape([-1]), dtype="int32" + ) + eval_inputs["natoms_vec"] = paddle.to_tensor( + natoms_vec, dtype="int32", place="cpu" + ) + eval_inputs["box"] = paddle.to_tensor(np.reshape(cells, [-1]), dtype="float64") + + # if self.has_fparam: + # eval_inputs["fparam"] = paddle.to_tensor( + # np.reshape(fparam, [-1], dtype="float64") + # ) + # if self.has_aparam: + # eval_inputs["aparam"] = paddle.to_tensor( + # np.reshape(aparam, [-1], dtype="float64") + # ) + # if se.pbc: + # eval_inputs["default_mesh"] = paddle.to_tensor( + # make_default_mesh(cells), dtype="int32" + # ) + # else: + eval_inputs["default_mesh"] = paddle.to_tensor(np.array([], dtype=np.int32)) + + if hasattr(self, "st_model"): + # NOTE: 使用静态图模型推理 + eval_outputs = self.st_model( + eval_inputs["coord"], + eval_inputs["type"], + eval_inputs["natoms_vec"], + eval_inputs["box"], + eval_inputs["default_mesh"], + ) + eval_outputs = { + "atom_ener": eval_outputs[0], + "atom_virial": eval_outputs[1], + "atype": eval_outputs[2], + "coord": eval_outputs[3], + "energy": eval_outputs[4], + "force": eval_outputs[5], + "virial": eval_outputs[6], + } + else: + # NOTE: 使用动态图模型推理 + eval_outputs = self.model( + eval_inputs["coord"], + eval_inputs["type"], + eval_inputs["natoms_vec"], + eval_inputs["box"], + eval_inputs["default_mesh"], + eval_inputs, + suffix="", + reuse=False, + ) + dipole = eval_outputs["dipole"].numpy() + + return dipole + + # if atomic: + # ae = eval_outputs["atom_ener"].numpy() + # av = eval_outputs["atom_virial"].numpy() + # return energy, force, virial, ae, av + # else: + # return energy, force, virial + def get_ntypes(self) -> int: """Get the number of atom types of this model.""" return self.ntypes @@ -136,12 +384,12 @@ def eval( coords: np.ndarray, cells: np.ndarray, atom_types: List[int], - atomic: bool = True, + atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, mixed_type: bool = False, - ) -> np.ndarray: + ) -> Tuple[np.ndarray, ...]: """Evaluate the model. 
Parameters @@ -172,76 +420,44 @@ def eval( Returns ------- - tensor - The returned tensor - If atomic == False then of size nframes x output_dim - else of size nframes x natoms x output_dim + energy + The system energy. + force + The force on each atom + virial + The virial + atom_energy + The atomic energy. Only returned when atomic == True + atom_virial + The atomic virial. Only returned when atomic == True """ # standarize the shape of inputs - if mixed_type: - natoms = atom_types[0].size - atom_types = np.array(atom_types, dtype=int).reshape([-1, natoms]) - else: - atom_types = np.array(atom_types, dtype=int).reshape([-1]) - natoms = atom_types.size - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - nframes = coords.shape[0] - if cells is None: - pbc = False - cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9]) - else: - pbc = True - cells = np.array(cells).reshape([nframes, 9]) - - # sort inputs - coords, atom_types, imap, sel_at, sel_imap = self.sort_input( - coords, atom_types, sel_atoms=self.get_sel_type(), mixed_type=mixed_type + # reshape coords before getting shape + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, mixed_type=mixed_type + ) # 192, 30 + output = self._eval_func(self._eval_inner, numb_test, natoms)( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + atomic=atomic, + efield=efield, + mixed_type=mixed_type, ) - - # make natoms_vec and default_mesh - natoms_vec = self.make_natoms_vec(atom_types, mixed_type=mixed_type) - assert natoms_vec[0] == natoms - - # evaluate - feed_dict_test = {} - feed_dict_test[self.t_natoms] = natoms_vec - if mixed_type: - feed_dict_test[self.t_type] = atom_types.reshape([-1]) - else: - feed_dict_test[self.t_type] = np.tile(atom_types, [nframes, 1]).reshape( - [-1] - ) - feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) - feed_dict_test[self.t_box] = np.reshape(cells, [-1]) - if pbc: - feed_dict_test[self.t_mesh] = make_default_mesh(cells) - else: - feed_dict_test[self.t_mesh] = np.array([], dtype=np.int32) - - if atomic: - assert ( - "global" not in self.model_type - ), f"cannot do atomic evaluation with model type {self.model_type}" - t_out = [self.t_tensor] - else: - assert ( - self._support_gfv or "global" in self.model_type - ), f"do not support global tensor evaluation with old {self.model_type} model" - t_out = [self.t_global_tensor if self._support_gfv else self.t_tensor] - v_out = self.sess.run(t_out, feed_dict=feed_dict_test) - tensor = v_out[0] - - # reverse map of the outputs - if atomic: - tensor = np.array(tensor) - tensor = self.reverse_map( - np.reshape(tensor, [nframes, -1, self.output_dim]), sel_imap - ) - tensor = np.reshape(tensor, [nframes, len(sel_at), self.output_dim]) - else: - tensor = np.reshape(tensor, [nframes, self.output_dim]) - - return tensor + # if self.modifier_type is not None: + # if atomic: + # raise RuntimeError("modifier does not support atomic modification") + # me, mf, mv = self.dm.eval(coords, cells, atom_types) + # output = list(output) # tuple to list + # e, f, v = output[:3] + # output[0] += me.reshape(e.shape) + # output[1] += mf.reshape(f.shape) + # output[2] += mv.reshape(v.shape) + # output = tuple(output) + + return output def eval_full( self, diff --git a/deepmd/loss/tensor.py b/deepmd/loss/tensor.py index e261be2bb1..be3fdf9161 100644 --- a/deepmd/loss/tensor.py +++ b/deepmd/loss/tensor.py @@ -4,11 +4,7 @@ add_data_requirement, ) from deepmd.env import ( - global_cvt_2_tf_float, - tf, -) -from deepmd.utils.sess import ( - 
run_sess, + paddle, ) from .loss import ( @@ -67,30 +63,31 @@ def __init__(self, jdata, **kwarg): type_sel=self.type_sel, ) - def build(self, learning_rate, natoms, model_dict, label_dict, suffix): + def compute_loss(self, learning_rate, natoms, model_dict, label_dict, suffix): polar_hat = label_dict[self.label_name] atomic_polar_hat = label_dict["atomic_" + self.label_name] - polar = tf.reshape(model_dict[self.tensor_name], [-1]) + polar = paddle.reshape(model_dict[self.tensor_name], [-1]) find_global = label_dict["find_" + self.label_name] find_atomic = label_dict["find_atomic_" + self.label_name] # YHT: added for global / local dipole combination - l2_loss = global_cvt_2_tf_float(0.0) + l2_loss = 0.0 more_loss = { - "local_loss": global_cvt_2_tf_float(0.0), - "global_loss": global_cvt_2_tf_float(0.0), + "local_loss": 0.0, + "global_loss": 0.0, } if self.local_weight > 0.0: - local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean( - tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix + local_loss = find_atomic * paddle.mean( + paddle.square(self.scale * (polar - atomic_polar_hat)), + name="l2_" + suffix, ) more_loss["local_loss"] = local_loss l2_loss += self.local_weight * local_loss - self.l2_loss_local_summary = tf.summary.scalar( - "l2_local_loss_" + suffix, tf.sqrt(more_loss["local_loss"]) - ) + # self.l2_loss_local_summary = paddle.summary.scalar( + # "l2_local_loss_" + suffix, paddle.sqrt(more_loss["local_loss"]) + # ) if self.global_weight > 0.0: # Need global loss atoms = 0 @@ -99,33 +96,34 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix): atoms += natoms[2 + w] else: atoms = natoms[0] - nframes = tf.shape(polar)[0] // self.tensor_size // atoms + nframes = paddle.shape(polar)[0] // self.tensor_size // atoms # get global results - global_polar = tf.reshape( - tf.reduce_sum( - tf.reshape(polar, [nframes, -1, self.tensor_size]), axis=1 + global_polar = paddle.reshape( + paddle.sum( + paddle.reshape(polar, [nframes, -1, self.tensor_size]), axis=1 ), [-1], ) # if self.atomic: # If label is local, however - # global_polar_hat = tf.reshape(tf.reduce_sum(tf.reshape( + # global_polar_hat = paddle.reshape(paddle.sum(paddle.reshape( # polar_hat, [nframes, -1, self.tensor_size]), axis=1),[-1]) # else: # global_polar_hat = polar_hat - global_loss = global_cvt_2_tf_float(find_global) * tf.reduce_mean( - tf.square(self.scale * (global_polar - polar_hat)), name="l2_" + suffix + global_loss = find_global * paddle.mean( + paddle.square(self.scale * (global_polar - polar_hat)), + name="l2_" + suffix, ) more_loss["global_loss"] = global_loss - self.l2_loss_global_summary = tf.summary.scalar( - "l2_global_loss_" + suffix, - tf.sqrt(more_loss["global_loss"]) / global_cvt_2_tf_float(atoms), - ) + # self.l2_loss_global_summary = paddle.summary.scalar( + # "l2_global_loss_" + suffix, + # paddle.sqrt(more_loss["global_loss"]) / global_cvt_2_tf_float(atoms), + # ) # YWolfeee: should only consider atoms with dipole, i.e. 
atoms # atom_norm = 1./ global_cvt_2_tf_float(natoms[0]) - atom_norm = 1.0 / global_cvt_2_tf_float(atoms) + atom_norm = 1.0 / atoms global_loss *= atom_norm l2_loss += self.global_weight * global_loss @@ -133,10 +131,10 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix): self.l2_more = more_loss self.l2_l = l2_loss - self.l2_loss_summary = tf.summary.scalar("l2_loss_" + suffix, tf.sqrt(l2_loss)) + # self.l2_loss_summary = paddle.summary.scalar("l2_loss_" + suffix, tf.sqrt(l2_loss)) return l2_loss, more_loss - def eval(self, sess, feed_dict, natoms): + def eval(self, model, batch_data, natoms): atoms = 0 if self.type_sel is not None: for w in self.type_sel: @@ -144,8 +142,54 @@ def eval(self, sess, feed_dict, natoms): else: atoms = natoms[0] - run_data = [self.l2_l, self.l2_more["local_loss"], self.l2_more["global_loss"]] - error, error_lc, error_gl = run_sess(sess, run_data, feed_dict=feed_dict) + model_inputs = {} + for kk in batch_data.keys(): + if kk == "find_type" or kk == "type": + continue + prec = "float64" + if "find_" in kk: + model_inputs[kk] = paddle.to_tensor(batch_data[kk], dtype="float64") + else: + model_inputs[kk] = paddle.to_tensor( + np.reshape(batch_data[kk], [-1]), dtype=prec + ) + + for ii in ["type"]: + model_inputs[ii] = paddle.to_tensor( + np.reshape(batch_data[ii], [-1]), dtype="int32" + ) + for ii in ["natoms_vec", "default_mesh"]: + model_inputs[ii] = paddle.to_tensor(batch_data[ii], dtype="int32") + model_inputs["is_training"] = paddle.to_tensor(False) + model_inputs["natoms_vec"] = paddle.to_tensor( + model_inputs["natoms_vec"], place="cpu" + ) + + model_pred = model( + model_inputs["coord"], + model_inputs["type"], + model_inputs["natoms_vec"], + model_inputs["box"], + model_inputs["default_mesh"], + model_inputs, + suffix="", + reuse=False, + ) + + l2_l, l2_more = self.compute_loss( + 0.0, + model_inputs["natoms_vec"], + model_pred, + model_inputs, + suffix="test", + ) + + run_data = [ + (float(l2_l)), + (float(l2_more["local_loss"])), + (float(l2_more["global_loss"])), + ] + error, error_lc, error_gl = run_data results = {"natoms": atoms, "rmse": np.sqrt(error)} if self.local_weight > 0.0: diff --git a/deepmd/model/tensor.py b/deepmd/model/tensor.py index eb32fa88cf..5992e15e3d 100644 --- a/deepmd/model/tensor.py +++ b/deepmd/model/tensor.py @@ -4,7 +4,7 @@ ) from deepmd.env import ( - MODEL_VERSION, + paddle, tf, ) @@ -17,7 +17,7 @@ ) -class TensorModel(Model): +class TensorModel(Model, paddle.nn.Layer): """Tensor model. 
Parameters @@ -49,6 +49,7 @@ def __init__( data_stat_nbatch: int = 10, data_stat_protect: float = 1e-2, ) -> None: + super().__init__() """Constructor.""" self.model_type = tensor_name # descriptor @@ -104,7 +105,7 @@ def _compute_output_stat(self, all_stat): if hasattr(self.fitting, "compute_output_stats"): self.fitting.compute_output_stats(all_stat) - def build( + def forward( self, coord_, atype_, @@ -119,63 +120,68 @@ def build( ): if input_dict is None: input_dict = {} - with tf.variable_scope("model_attr" + suffix, reuse=reuse): - t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string) - t_st = tf.constant(self.get_sel_type(), name="sel_type", dtype=tf.int32) - t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) - t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) - t_od = tf.constant(self.get_out_size(), name="output_dim", dtype=tf.int32) - - natomsel = sum(natoms[2 + type_i] for type_i in self.get_sel_type()) - nout = self.get_out_size() - - coord = tf.reshape(coord_, [-1, natoms[1] * 3]) - atype = tf.reshape(atype_, [-1, natoms[1]]) - input_dict["nframes"] = tf.shape(coord)[0] + # with tf.variable_scope("model_attr" + suffix, reuse=reuse): + # t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string) + # t_st = tf.constant(self.get_sel_type(), name="sel_type", dtype=tf.int32) + # t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) + # t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) + # t_od = tf.constant(self.get_out_size(), name="output_dim", dtype=tf.int32) + + natomsel = sum( + natoms[2 + type_i] for type_i in self.get_sel_type() + ) # n_atom_selected + nout = self.get_out_size() # 3 + + coord = paddle.reshape(coord_, [-1, natoms[1] * 3]) + atype = paddle.reshape(atype_, [-1, natoms[1]]) + # input_dict["nframes"] = paddle.shape(coord)[0]# 推理模型导出的时候注释掉这里,否则会报错 # type embedding if any - if self.typeebd is not None: - type_embedding = self.typeebd.build( - self.ntypes, - reuse=reuse, - suffix=suffix, - ) - input_dict["type_embedding"] = type_embedding - input_dict["atype"] = atype_ - - dout = self.build_descrpt( + # if self.typeebd is not None: + # type_embedding = self.typeebd.build( + # self.ntypes, + # reuse=reuse, + # suffix=suffix, + # ) + # input_dict["type_embedding"] = type_embedding + input_dict["atype"] = atype_ + + dout = self.descrpt( coord, atype, natoms, box, mesh, input_dict, - frz_model=frz_model, - ckpt_meta=ckpt_meta, + # frz_model=frz_model, + # ckpt_meta=ckpt_meta, suffix=suffix, reuse=reuse, ) rot_mat = self.descrpt.get_rot_mat() - rot_mat = tf.identity(rot_mat, name="o_rot_mat" + suffix) + rot_mat = paddle.clone(rot_mat, name="o_rot_mat" + suffix) - output = self.fitting.build( + output = self.fitting( dout, rot_mat, natoms, input_dict, reuse=reuse, suffix=suffix ) + framesize = nout if "global" in self.model_type else natomsel * nout - output = tf.reshape( + output = paddle.reshape( output, [-1, framesize], name="o_" + self.model_type + suffix ) model_dict = {self.model_type: output} if "global" not in self.model_type: - gname = "global_" + self.model_type - atom_out = tf.reshape(output, [-1, natomsel, nout]) - global_out = tf.reduce_sum(atom_out, axis=1) - global_out = tf.reshape(global_out, [-1, nout], name="o_" + gname + suffix) + gname = "global_" + self.model_type # "global_dipole" + atom_out = paddle.reshape(output, [-1, natomsel, nout]) # nout=3 + global_out = paddle.sum(atom_out, axis=1) + global_out = paddle.reshape( + 
global_out, [-1, nout], name="o_" + gname + suffix + ) - out_cpnts = tf.split(atom_out, nout, axis=-1) + out_cpnts = paddle.split(atom_out, nout, axis=-1) force_cpnts = [] virial_cpnts = [] atom_virial_cpnts = [] @@ -184,16 +190,18 @@ def build( force_i, virial_i, atom_virial_i = self.descrpt.prod_force_virial( out_i, natoms ) - force_cpnts.append(tf.reshape(force_i, [-1, 3 * natoms[1]])) - virial_cpnts.append(tf.reshape(virial_i, [-1, 9])) - atom_virial_cpnts.append(tf.reshape(atom_virial_i, [-1, 9 * natoms[1]])) + force_cpnts.append(paddle.reshape(force_i, [-1, 3 * natoms[1]])) + virial_cpnts.append(paddle.reshape(virial_i, [-1, 9])) + atom_virial_cpnts.append( + paddle.reshape(atom_virial_i, [-1, 9 * natoms[1]]) + ) # [nframe x nout x (natom x 3)] - force = tf.concat(force_cpnts, axis=1, name="o_force" + suffix) + force = paddle.concat(force_cpnts, axis=1, name="o_force" + suffix) # [nframe x nout x 9] - virial = tf.concat(virial_cpnts, axis=1, name="o_virial" + suffix) + virial = paddle.concat(virial_cpnts, axis=1, name="o_virial" + suffix) # [nframe x nout x (natom x 9)] - atom_virial = tf.concat( + atom_virial = paddle.concat( atom_virial_cpnts, axis=1, name="o_atom_virial" + suffix ) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 8f5051570b..d8a4c811e9 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -29,6 +29,7 @@ ) from deepmd.fit import ( Fitting, + dipole, ener, ) from deepmd.loss import ( @@ -169,7 +170,12 @@ def _init_param(self, jdata): fitting_param["mixed_prec"] = self.mixed_prec if fitting_param["mixed_prec"] is not None: fitting_param["precision"]: str = self.mixed_prec["output_prec"] - self.fitting = ener.EnerFitting(**fitting_param) + self.fitting = ener.EnerFitting(**fitting_param) + elif fitting_type == "dipole": + fitting_param.pop("type") + self.fitting = dipole.DipoleFittingSeA(**fitting_param) + else: + raise NotImplementedError else: raise NotImplementedError("multi-task mode is not supported") self.fitting_dict = {}
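
Notes on the TF -> Paddle migration, with illustrative sketches.

The dipole.py hunks move network construction out of the graph-building
path: the repeated graph-mode one_layer(...) calls become OneLayer_deepmd
modules instantiated once in __init__ (one stack per atom type, held in an
nn.LayerList) and indexed by type_i inside _build_lower. A minimal sketch
of that pattern, with hypothetical names (TinyFitting and nn.Linear stand
in for the real fitting net and OneLayer_deepmd):

    import paddle
    from paddle import nn

    # Hypothetical miniature of the DipoleFittingSeA pattern above: per-type
    # layer stacks are built once in __init__ and indexed by type_i in
    # forward, instead of materializing TF variables inside build().
    class TinyFitting(nn.Layer):
        def __init__(self, dim_in, neuron, ntypes):
            super().__init__()
            self.nets = nn.LayerList()
            for _ in range(ntypes):
                stack, last = nn.LayerList(), dim_in
                for width in neuron:
                    stack.append(nn.Linear(last, width))
                    last = width
                self.nets.append(stack)

        def forward(self, x, type_i):
            for layer in self.nets[type_i]:
                x = paddle.tanh(layer(x))
            return x

    net = TinyFitting(dim_in=8, neuron=[16, 16], ntypes=2)
    out = net(paddle.randn([4, 8]), type_i=1)  # -> shape [4, 16]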
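The same hunks also translate slicing semantics: tf.slice takes (begin,
size) with a size of -1 meaning "to the end", while paddle.slice takes
explicit (axes, starts, ends). A standalone check of the equivalence used
in _build_lower:

    import numpy as np
    import paddle

    # tf.slice(x, [0, start, 0], [-1, natoms, -1]) keeps `natoms` rows along
    # axis 1; paddle.slice spells the same window with explicit
    # (axes, starts, ends), which is what _build_lower now does.
    x = paddle.to_tensor(np.arange(24).reshape([2, 4, 3]))
    start, natoms = 1, 2
    y = paddle.slice(
        x,
        axes=[0, 1, 2],
        starts=[0, start, 0],
        ends=[x.shape[0], start + natoms, x.shape[2]],
    )
    assert y.shape == [2, 2, 3]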
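In loss/tensor.py, build() becomes an eager compute_loss() that mixes a
per-atom term against the atomic labels with a frame-level term obtained
by summing the per-atom predictions over each frame; the global term is
additionally normalized by the number of selected atoms. A numpy sketch of
that weighting, with assumed names and shapes (`polar` is the flat
per-atom prediction of length nframes * atoms * tsize):

    import numpy as np

    # Minimal numpy sketch of the weighting in TensorLoss.compute_loss;
    # names and shapes are assumptions, not the library API.
    def tensor_l2(polar, atomic_hat, global_hat, scale,
                  w_local, w_global, atoms, tsize):
        loss = 0.0
        if w_local > 0.0:
            # per-atom (local) term against atomic labels
            loss += w_local * np.mean((scale * (polar - atomic_hat)) ** 2)
        if w_global > 0.0:
            # sum atoms within each frame, then normalize by atom count
            glob = polar.reshape(-1, atoms, tsize).sum(axis=1).reshape(-1)
            loss += (w_global / atoms
                     * np.mean((scale * (glob - global_hat)) ** 2))
        return loss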
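Finally, a hypothetical smoke test of the migrated inference path in
deep_tensor.py / deep_eval.py. It assumes a trained Paddle checkpoint
("dipole.pdparams" is a placeholder) plus the training JSON in the working
directory, which the patched DeepEval reads as "input.json" (falling back
to "dipole_input.json"); the constructor signature is assumed to match the
TF DeepDipole:

    import numpy as np
    from deepmd.infer import DeepDipole  # assumed import path

    dd = DeepDipole("dipole.pdparams")           # placeholder model file
    coords = np.random.rand(1, 6, 3)             # 1 frame, 6 atoms
    cells = 10.0 * np.eye(3).reshape(1, 9)       # cubic box
    atom_types = [0, 0, 0, 1, 1, 1]
    dipole = dd.eval(coords, cells, atom_types)  # flat per-atom dipoles
    print(dipole.shape)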