diff --git a/deepmd/model_format/__init__.py b/deepmd/model_format/__init__.py index 253bca3507..e15f73758e 100644 --- a/deepmd/model_format/__init__.py +++ b/deepmd/model_format/__init__.py @@ -7,6 +7,9 @@ from .env_mat import ( EnvMat, ) +from .fitting import ( + InvarFitting, +) from .network import ( EmbeddingNet, FittingNet, @@ -34,6 +37,7 @@ ) __all__ = [ + "InvarFitting", "DescrptSeA", "EnvMat", "make_multilayer_network", diff --git a/deepmd/model_format/fitting.py b/deepmd/model_format/fitting.py new file mode 100644 index 0000000000..904fb42b76 --- /dev/null +++ b/deepmd/model_format/fitting.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + Any, + List, + Optional, +) + +import numpy as np + +from .common import ( + DEFAULT_PRECISION, + NativeOP, +) +from .network import ( + FittingNet, + NetworkCollection, +) +from .output_def import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) + + +@fitting_check_output +class InvarFitting(NativeOP): + r"""Fitting the energy (or a porperty of `dim_out`) of the system. The force and the virial can also be trained. + + Lets take the energy fitting task as an example. + The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`: + + .. math:: + E(\mathcal{D}) = \mathcal{L}^{(n)} \circ \mathcal{L}^{(n-1)} + \circ \cdots \circ \mathcal{L}^{(1)} \circ \mathcal{L}^{(0)} + + The first :math:`n` hidden layers :math:`\mathcal{L}^{(0)}, \cdots, \mathcal{L}^{(n-1)}` are given by + + .. math:: + \mathbf{y}=\mathcal{L}(\mathbf{x};\mathbf{w},\mathbf{b})= + \boldsymbol{\phi}(\mathbf{x}^T\mathbf{w}+\mathbf{b}) + + where :math:`\mathbf{x} \in \mathbb{R}^{N_1}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}^{N_2}` + is the output vector. :math:`\mathbf{w} \in \mathbb{R}^{N_1 \times N_2}` and + :math:`\mathbf{b} \in \mathbb{R}^{N_2}` are weights and biases, respectively, + both of which are trainable if `trainable[i]` is `True`. :math:`\boldsymbol{\phi}` + is the activation function. + + The output layer :math:`\mathcal{L}^{(n)}` is given by + + .. math:: + \mathbf{y}=\mathcal{L}^{(n)}(\mathbf{x};\mathbf{w},\mathbf{b})= + \mathbf{x}^T\mathbf{w}+\mathbf{b} + + where :math:`\mathbf{x} \in \mathbb{R}^{N_{n-1}}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}` + is the output scalar. :math:`\mathbf{w} \in \mathbb{R}^{N_{n-1}}` and + :math:`\mathbf{b} \in \mathbb{R}` are weights and bias, respectively, + both of which are trainable if `trainable[n]` is `True`. + + Parameters + ---------- + var_name + The name of the output variable. + ntypes + The number of atom types. + dim_descrpt + The dimension of the input descriptor. + dim_out + The dimension of the output fit property. + neuron + Number of neurons :math:`N` in each hidden layer of the fitting net + resnet_dt + Time-step `dt` in the resnet construction: + :math:`y = x + dt * \phi (Wx + b)` + numb_fparam + Number of frame parameter + numb_aparam + Number of atomic parameter + rcond + The condition number for the regression of atomic energy. + tot_ener_zero + Force the total energy to zero. Useful for the charge fitting. + trainable + If the weights of fitting net are trainable. + Suppose that we have :math:`N_l` hidden layers in the fitting net, + this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable. + atom_ener + Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. 
+ activation_function + The activation function :math:`\boldsymbol{\phi}` in the embedding net. Supported options are |ACTIVATION_FN| + precision + The precision of the embedding net parameters. Supported options are |PRECISION| + layer_name : list[Optional[str]], optional + The name of the each layer. If two layers, either in the same fitting or different fittings, + have the same name, they will share the same neural network parameters. + use_aparam_as_mask: bool, optional + If True, the atomic parameters will be used as a mask that determines the atom is real/virtual. + And the aparam will not be used as the atomic parameters for embedding. + distinguish_types + Different atomic types uses different fitting net. + + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out: int, + neuron: List[int] = [120, 120, 120], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + rcond: Optional[float] = None, + tot_ener_zero: bool = False, + trainable: Optional[List[bool]] = None, + atom_ener: Optional[List[float]] = None, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + layer_name: Optional[List[Optional[str]]] = None, + use_aparam_as_mask: bool = False, + spin: Any = None, + distinguish_types: bool = False, + ): + # seed, uniform_seed are not included + if tot_ener_zero: + raise NotImplementedError("tot_ener_zero is not implemented") + if spin is not None: + raise NotImplementedError("spin is not implemented") + if use_aparam_as_mask: + raise NotImplementedError("use_aparam_as_mask is not implemented") + if use_aparam_as_mask: + raise NotImplementedError("use_aparam_as_mask is not implemented") + if layer_name is not None: + raise NotImplementedError("layer_name is not implemented") + if atom_ener is not None: + raise NotImplementedError("atom_ener is not implemented") + + self.var_name = var_name + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt + self.dim_out = dim_out + self.neuron = neuron + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.rcond = rcond + self.tot_ener_zero = tot_ener_zero + self.trainable = trainable + self.atom_ener = atom_ener + self.activation_function = activation_function + self.precision = precision + self.layer_name = layer_name + self.use_aparam_as_mask = use_aparam_as_mask + self.spin = spin + self.distinguish_types = distinguish_types + if self.spin is not None: + raise NotImplementedError("spin is not supported") + + # init constants + self.bias_atom_e = np.zeros([self.ntypes, self.dim_out]) + if self.numb_fparam > 0: + self.fparam_avg = np.zeros(self.numb_fparam) + self.fparam_inv_std = np.ones(self.numb_fparam) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.aparam_avg = np.zeros(self.numb_aparam) + self.aparam_inv_std = np.ones(self.numb_aparam) + else: + self.aparam_avg, self.aparam_inv_std = None, None + # init networks + in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam + out_dim = self.dim_out + self.nets = NetworkCollection( + 1 if self.distinguish_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + out_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + ) + for ii in range(self.ntypes if self.distinguish_types else 1) + ], + ) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, [self.dim_out], reduciable=True, 
differentiable=True + ), + ] + ) + + def __setitem__(self, key, value): + if key in ["bias_atom_e"]: + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + else: + raise KeyError(key) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out": self.dim_out, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "rcond": self.rcond, + "activation_function": self.activation_function, + "precision": self.precision, + "distinguish_types": self.distinguish_types, + "nets": self.nets.serialize(), + "@variables": { + "bias_atom_e": self.bias_atom_e, + "fparam_avg": self.fparam_avg, + "fparam_inv_std": self.fparam_inv_std, + "aparam_avg": self.aparam_avg, + "aparam_inv_std": self.aparam_inv_std, + }, + # not supported + "tot_ener_zero": self.tot_ener_zero, + "trainable": self.trainable, + "atom_ener": self.atom_ener, + "layer_name": self.layer_name, + "use_aparam_as_mask": self.use_aparam_as_mask, + "spin": self.spin, + } + + @classmethod + def deserialize(cls, data: dict) -> "InvarFitting": + data = copy.deepcopy(data) + variables = data.pop("@variables") + nets = data.pop("nets") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = variables[kk] + obj.nets = NetworkCollection.deserialize(nets) + return obj + + def call( + self, + descriptor: np.array, + atype: np.array, + gr: Optional[np.array] = None, + g2: Optional[np.array] = None, + h2: Optional[np.array] = None, + fparam: Optional[np.array] = None, + aparam: Optional[np.array] = None, + ): + """Calculate the fitting. + + Parameters + ---------- + descriptor + input descriptor. shape: nf x nloc x nd + atype + the atom type. shape: nf x nloc + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + fparam + The frame parameter. shape: nf x nfp. nfp being `numb_fparam` + aparam + The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` + + """ + nf, nloc, nd = descriptor.shape + # check input dim + if nd != self.dim_descrpt: + raise ValueError( + "get an input descriptor of dim {nd}," + "which is not consistent with {self.dim_descrpt}." 
+ ) + xx = descriptor + # check fparam dim, concate to input descriptor + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + if fparam.shape[-1] != self.numb_fparam: + raise ValueError( + "get an input fparam of dim {fparam.shape[-1]}, ", + "which is not consistent with {self.numb_fparam}.", + ) + fparam = (fparam - self.fparam_avg) * self.fparam_inv_std + fparam = np.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + xx = np.concatenate( + [xx, fparam], + axis=-1, + ) + # check aparam dim, concate to input descriptor + if self.numb_aparam > 0: + assert aparam is not None, "aparam should not be None" + if aparam.shape[-1] != self.numb_aparam: + raise ValueError( + "get an input aparam of dim {aparam.shape[-1]}, ", + "which is not consistent with {self.numb_aparam}.", + ) + aparam = (aparam - self.aparam_avg) * self.aparam_inv_std + xx = np.concatenate( + [xx, aparam], + axis=-1, + ) + + # calcualte the prediction + if self.distinguish_types: + outs = np.zeros([nf, nloc, self.dim_out]) + for type_i in range(self.ntypes): + mask = np.tile( + (atype == type_i).reshape([nf, nloc, 1]), [1, 1, self.dim_out] + ) + atom_energy = self.nets[(type_i,)](xx) + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + else: + outs = self.nets[()](xx) + self.bias_atom_e[atype] + return {self.var_name: outs} diff --git a/deepmd/model_format/network.py b/deepmd/model_format/network.py index a327d990c9..f2056c0b95 100644 --- a/deepmd/model_format/network.py +++ b/deepmd/model_format/network.py @@ -161,6 +161,8 @@ def __init__( ) -> None: prec = PRECISION_DICT[precision.lower()] self.precision = precision + # only use_timestep when skip connection is established. 
+ use_timestep = use_timestep and (num_out == num_in or num_out == num_in * 2) rng = np.random.default_rng() self.w = rng.normal(size=(num_in, num_out)).astype(prec) self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None diff --git a/deepmd/model_format/se_e2_a.py b/deepmd/model_format/se_e2_a.py index 28751cad8d..f179b10ac3 100644 --- a/deepmd/model_format/se_e2_a.py +++ b/deepmd/model_format/se_e2_a.py @@ -171,9 +171,8 @@ def __init__( ) self.env_mat = EnvMat(self.rcut, self.rcut_smth) self.nnei = np.sum(self.sel) - self.nneix4 = self.nnei * 4 - self.davg = np.zeros([self.ntypes, self.nneix4]) - self.dstd = np.ones([self.ntypes, self.nneix4]) + self.davg = np.zeros([self.ntypes, self.nnei, 4]) + self.dstd = np.ones([self.ntypes, self.nnei, 4]) self.orig_sel = self.sel def __setitem__(self, key, value): @@ -192,6 +191,11 @@ def __getitem__(self, key): else: raise KeyError(key) + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.neuron[-1] * self.axis_neuron + def cal_g( self, ss, diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index 853eacb875..a222c8e6f6 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -93,11 +93,11 @@ def __init__( ) fitting_net["type"] = fitting_net.get("type", "ener") - if self.descriptor_type not in ["se_e2_a"]: - fitting_net["ntypes"] = 1 + fitting_net["ntypes"] = self.descriptor.get_ntype() + if self.descriptor_type in ["se_e2_a"]: + fitting_net["distinguish_types"] = True else: - fitting_net["ntypes"] = self.descriptor.get_ntype() - fitting_net["use_tebd"] = False + fitting_net["distinguish_types"] = False fitting_net["embedding_width"] = self.descriptor.dim_out self.grad_force = "direct" not in fitting_net["type"] @@ -165,5 +165,5 @@ def forward_atomic( ) assert descriptor is not None # energy, force - fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat) + fit_ret = self.fitting_net(descriptor, atype, gr=rot_mat) return fit_ret diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index e3ac0e7bc2..d76abd82f9 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -56,7 +56,10 @@ def __init__( precision: str = DEFAULT_PRECISION, ): super().__init__() - self.use_timestep = use_timestep + # only use_timestep when skip connection is established. 
+ self.use_timestep = use_timestep and ( + num_out == num_in or num_out == num_in * 2 + ) self.activate_name = activation_function self.activate = ActivationFn(self.activate_name) self.precision = precision @@ -207,7 +210,7 @@ class NetworkCollection(DPNetworkCollection, nn.Module): NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { "network": MLP, "embedding_network": EmbeddingNet, - # "fitting_network": FittingNet, + "fitting_network": FittingNet, } def __init__(self, *args, **kwargs): diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 03043e2fcb..91ccd03d9a 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( + List, Optional, Tuple, ) +import numpy as np import torch from deepmd.model_format import ( @@ -12,6 +15,10 @@ OutputVariableDef, fitting_check_output, ) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, +) from deepmd.pt.model.network.network import ( ResidualDeep, ) @@ -21,19 +28,35 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE -@Fitting.register("ener") @fitting_check_output -class EnergyFittingNet(Fitting): +class InvarFitting(Fitting): def __init__( self, - ntypes, - embedding_width, - neuron, - bias_atom_e, - resnet_dt=True, - use_tebd=True, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out: int, + neuron: List[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + distinguish_types: bool = False, **kwargs, ): """Construct a fitting net for energy. @@ -46,67 +69,323 @@ def __init__( - resnet_dt: Using time-step in the ResNet construction. """ super().__init__() + self.var_name = var_name self.ntypes = ntypes - self.embedding_width = embedding_width - self.use_tebd = use_tebd - if not use_tebd: - assert self.ntypes == len(bias_atom_e), "Element count mismatches!" - bias_atom_e = torch.tensor(bias_atom_e) + self.dim_descrpt = dim_descrpt + self.dim_out = dim_out + self.neuron = neuron + self.distinguish_types = distinguish_types + self.use_tebd = not self.distinguish_types + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, self.dim_out]) + bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device) + bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out]) + if not self.use_tebd: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" 
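
The `use_timestep` guard added to both `MLPLayer` implementations in this patch (numpy `deepmd/model_format/network.py` and torch `deepmd/pt/model/network/mlp.py`) enables the ResNet timestep only when the layer can actually form a skip connection, i.e. when `num_out == num_in` or `num_out == num_in * 2`. A minimal numpy sketch of the idea; the function name is illustrative and this is not the library's `MLPLayer` code:

    import numpy as np

    def resnet_layer(x, w, b, idt=None):
        # x: (..., num_in); w: (num_in, num_out); b: (num_out,); idt: (num_out,) or None
        num_in, num_out = w.shape
        y = np.tanh(x @ w + b)
        if num_out == num_in:          # plain skip connection
            return x + (y * idt if idt is not None else y)
        if num_out == num_in * 2:      # widened skip: duplicate the input
            return np.concatenate([x, x], axis=-1) + (y * idt if idt is not None else y)
        return y                       # no skip is possible, so a timestep would be redundant
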
self.register_buffer("bias_atom_e", bias_atom_e) + # init constants + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=device), + ) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=device), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None - filter_layers = [] - for type_i in range(self.ntypes): - bias_type = 0.0 - one = ResidualDeep( - type_i, embedding_width, neuron, bias_type, resnet_dt=resnet_dt + in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam + out_dim = 1 + + self.old_impl = kwargs.get("old_impl", False) + if self.old_impl: + filter_layers = [] + for type_i in range(self.ntypes): + bias_type = 0.0 + one = ResidualDeep( + type_i, + self.dim_descrpt, + self.neuron, + bias_type, + resnet_dt=self.resnet_dt, + ) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + self.filter_layers = None + else: + self.filter_layers = NetworkCollection( + 1 if self.distinguish_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + out_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + ) + for ii in range(self.ntypes if self.distinguish_types else 1) + ], ) - filter_layers.append(one) - self.filter_layers = torch.nn.ModuleList(filter_layers) + self.filter_layers_old = None + # very bad design... if "seed" in kwargs: logging.info("Set seed to %d in fitting net.", kwargs["seed"]) torch.manual_seed(kwargs["seed"]) - def output_def(self): + def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ - OutputVariableDef("energy", [1], reduciable=True, differentiable=True), + OutputVariableDef( + self.var_name, [self.dim_out], reduciable=True, differentiable=True + ), ] ) + def __setitem__(self, key, value): + if key in ["bias_atom_e"]: + # correct bias_atom_e shape. 
user may provide stupid shape + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + else: + raise KeyError(key) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out": self.dim_out, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "activation_function": self.activation_function, + "precision": self.precision, + "distinguish_types": self.distinguish_types, + "nets": self.filter_layers.serialize(), + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + # "rcond": self.rcond , + # "tot_ener_zero": self.tot_ener_zero , + # "trainable": self.trainable , + # "atom_ener": self.atom_ener , + # "layer_name": self.layer_name , + # "use_aparam_as_mask": self.use_aparam_as_mask , + # "spin": self.spin , + ## NOTICE: not supported by far + "rcond": None, + "tot_ener_zero": False, + "trainable": True, + "atom_ener": None, + "layer_name": None, + "use_aparam_as_mask": False, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "InvarFitting": + data = copy.deepcopy(data) + variables = data.pop("@variables") + nets = data.pop("nets") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers = NetworkCollection.deserialize(nets) + return obj + + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: + return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) + + def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: + return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + def forward( self, - inputs: torch.Tensor, + descriptor: torch.Tensor, atype: torch.Tensor, - atype_tebd: Optional[torch.Tensor] = None, - rot_mat: Optional[torch.Tensor] = None, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, ): """Based on embedding net output, alculate total energy. Args: - - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. + - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt]. - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. Returns ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ + xx = descriptor + nf, nloc, nd = xx.shape + # NOTICE in tests/pt/test_model.py + # it happens that the user directly access the data memeber self.bias_atom_e + # and set it to a wrong shape! 
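
Both the numpy `call` above and the torch `forward` below fold the frame and atomic parameters into the fitting input the same way: standardize with the stored statistics, broadcast over the atoms, and concatenate onto the descriptor. A minimal numpy sketch of that step; the free-standing helper is illustrative only, not part of either class:

    import numpy as np

    def append_fparam(xx, fparam, fparam_avg, fparam_inv_std):
        # xx: (nf, nloc, nd) descriptor; fparam: (nf, nfp) frame parameters
        nf, nloc, _ = xx.shape
        fparam = (fparam - fparam_avg) * fparam_inv_std            # standardize
        fparam = np.tile(fparam.reshape(nf, 1, -1), (1, nloc, 1))  # broadcast to every atom
        return np.concatenate([xx, fparam], axis=-1)               # (nf, nloc, nd + nfp)

    # aparam is handled analogously, except that it already carries a per-atom
    # axis, shape (nf, nloc, nap), so no tiling is needed before concatenation.
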
+ self.bias_atom_e = self.bias_atom_e.view([self.ntypes, self.dim_out]) + # check input dim + if nd != self.dim_descrpt: + raise ValueError( + "get an input descriptor of dim {nd}," + "which is not consistent with {self.dim_descrpt}." + ) + # check fparam dim, concate to input descriptor + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + assert self.fparam_avg is not None + assert self.fparam_inv_std is not None + if fparam.shape[-1] != self.numb_fparam: + raise ValueError( + "get an input fparam of dim {fparam.shape[-1]}, ", + "which is not consistent with {self.numb_fparam}.", + ) + nb, _ = fparam.shape + t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) + t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) + fparam = (fparam - t_fparam_avg) * t_fparam_inv_std + fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + xx = torch.cat( + [xx, fparam], + dim=-1, + ) + # check aparam dim, concate to input descriptor + if self.numb_aparam > 0: + assert aparam is not None, "aparam should not be None" + assert self.aparam_avg is not None + assert self.aparam_inv_std is not None + if aparam.shape[-1] != self.numb_aparam: + raise ValueError( + "get an input aparam of dim {aparam.shape[-1]}, ", + "which is not consistent with {self.numb_aparam}.", + ) + nb, nloc, _ = aparam.shape + t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) + t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) + aparam = (aparam - t_aparam_avg) * t_aparam_inv_std + xx = torch.cat( + [xx, aparam], + dim=-1, + ) + outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion - if self.use_tebd: - if atype_tebd is not None: - inputs = torch.concat([inputs, atype_tebd], dim=-1) - atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[ - atype - ].unsqueeze(-1) - outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + if self.old_impl: + outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion + assert self.filter_layers_old is not None + if self.use_tebd: + atom_energy = self.filter_layers_old[0](xx) + self.bias_atom_e[ + atype + ].unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + else: + for type_i, filter_layer in enumerate(self.filter_layers_old): + mask = atype == type_i + atom_energy = filter_layer(xx) + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask.unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} else: - for type_i, filter_layer in enumerate(self.filter_layers): - mask = atype == type_i - atom_energy = filter_layer(inputs) - atom_energy = atom_energy + self.bias_atom_e[type_i] - atom_energy = atom_energy * mask.unsqueeze(-1) + if self.use_tebd: + atom_energy = ( + self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + ) outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] - return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + else: + for type_i, ll in enumerate(self.filter_layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, self.dim_out)) + atom_energy = ll(xx) + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + +@Fitting.register("ener") +@fitting_check_output +class EnergyFittingNet(InvarFitting): + def __init__( + 
self, + ntypes: int, + embedding_width: int, + neuron: List[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + use_tebd: bool = True, + **kwargs, + ): + super().__init__( + "energy", + ntypes, + embedding_width, + 1, + neuron=neuron, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + activation_function=activation_function, + precision=precision, + use_tebd=use_tebd, + **kwargs, + ) @Fitting.register("direct_force") @@ -136,7 +415,7 @@ def __init__( """ super().__init__() self.ntypes = ntypes - self.embedding_width = embedding_width + self.dim_descrpt = embedding_width self.use_tebd = use_tebd self.out_dim = out_dim if not use_tebd: @@ -186,13 +465,12 @@ def forward( self, inputs: torch.Tensor, atype: torch.Tensor, - atype_tebd: Optional[torch.Tensor] = None, - rot_mat: Optional[torch.Tensor] = None, + gr: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, None]: """Based on embedding net output, alculate total energy. Args: - - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. + - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt]. - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. Returns @@ -201,19 +479,19 @@ def forward( """ nframes, nloc, _ = inputs.size() if self.use_tebd: - if atype_tebd is not None: - inputs = torch.concat([inputs, atype_tebd], dim=-1) + # if atype_tebd is not None: + # inputs = torch.concat([inputs, atype_tebd], dim=-1) vec_out = self.filter_layers_dipole[0]( inputs ) # Shape is [nframes, nloc, m1] assert list(vec_out.size()) == [nframes, nloc, self.out_dim] # (nf x nloc) x 1 x od vec_out = vec_out.view(-1, 1, self.out_dim) - assert rot_mat is not None + assert gr is not None # (nf x nloc) x od x 3 - rot_mat = rot_mat.view(-1, self.out_dim, 3) + gr = gr.view(-1, self.out_dim, 3) vec_out = ( - torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3) + torch.bmm(vec_out, gr).squeeze(-2).view(nframes, nloc, 3) ) # Shape is [nframes, nloc, 3] else: vec_out = torch.zeros_like(atype).unsqueeze(-1) # jit assertion diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 16e80f9c20..c6fb6b27e1 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -7,9 +7,6 @@ import numpy as np import torch -from deepmd.model_format import ( - FittingOutputDef, -) from deepmd.pt.model.task.task import ( TaskBaseMethod, ) @@ -61,17 +58,9 @@ def __new__(cls, *args, **kwargs): if fitting_type in Fitting.__plugins.plugins: cls = Fitting.__plugins.plugins[fitting_type] else: - raise RuntimeError("Unknown descriptor type: " + fitting_type) + raise RuntimeError("Unknown fitting type: " + fitting_type) return super().__new__(cls) - def output_def(self) -> FittingOutputDef: - """Definition for the task Output.""" - raise NotImplementedError - - def forward(self, **kwargs): - """Task Output.""" - raise NotImplementedError - def share_params(self, base_class, shared_level, resume=False): assert ( self.__class__ == base_class.__class__ diff --git a/deepmd/pt/model/task/task.py b/deepmd/pt/model/task/task.py index a9b2efeb9a..b2dc03e4bd 100644 --- a/deepmd/pt/model/task/task.py +++ b/deepmd/pt/model/task/task.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) + import 
torch +from deepmd.model_format import ( + FittingOutputDef, +) -class TaskBaseMethod(torch.nn.Module): - def __init__(self, **kwargs): - """Construct a basic head for different tasks.""" - super().__init__() - def forward(self, **kwargs): - """Task Output.""" +class TaskBaseMethod(torch.nn.Module, ABC): + @abstractmethod + def output_def(self) -> FittingOutputDef: + """Definition for the task Output.""" raise NotImplementedError diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 780dbf7e62..e83e12f608 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -4,9 +4,17 @@ Optional, ) +import numpy as np import torch import torch.nn.functional as F +from deepmd.model_format.common import PRECISION_DICT as NP_PRECISION_DICT + +from .env import ( + DEVICE, +) +from .env import PRECISION_DICT as PT_PRECISION_DICT + def get_activation_fn(activation: str) -> Callable: """Returns the activation function corresponding to `activation`.""" @@ -41,3 +49,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x else: raise RuntimeError(f"activation function {self.activation} not supported") + + +def to_numpy_array( + xx: torch.Tensor, +) -> np.ndarray: + if xx is None: + return None + assert xx is not None + # Create a reverse mapping of PT_PRECISION_DICT + reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()} + # Use the reverse mapping to find keys with the desired value + prec = reverse_precision_dict.get(xx.dtype, None) + prec = NP_PRECISION_DICT.get(prec, None) + if prec is None: + raise ValueError(f"unknown precision {xx.dtype}") + return xx.detach().cpu().numpy().astype(prec) + + +def to_torch_tensor( + xx: np.ndarray, +) -> torch.Tensor: + if xx is None: + return None + assert xx is not None + # Create a reverse mapping of NP_PRECISION_DICT + reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()} + # Use the reverse mapping to find keys with the desired value + prec = reverse_precision_dict.get(type(xx.flat[0]), None) + prec = PT_PRECISION_DICT.get(prec, None) + if prec is None: + raise ValueError(f"unknown precision {xx.dtype}") + return torch.tensor(xx, dtype=prec, device=DEVICE) diff --git a/source/tests/common/test_model_format_utils.py b/source/tests/common/test_model_format_utils.py index da76c53ed9..cb85fd2bb2 100644 --- a/source/tests/common/test_model_format_utils.py +++ b/source/tests/common/test_model_format_utils.py @@ -13,6 +13,7 @@ EmbeddingNet, EnvMat, FittingNet, + InvarFitting, NativeLayer, NativeNet, NetworkCollection, @@ -369,3 +370,123 @@ def test_self_consistency( mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) for ii in [0, 1, 4]: np.testing.assert_allclose(mm0[ii], mm1[ii]) + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + distinguish_types, + od, + nfp, + nap, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + distinguish_types=distinguish_types, + ) + ifn1 = InvarFitting.deserialize(ifn0.serialize()) + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp)) + else: + ifp = None + if nap > 0: + iap = 
rng.normal(size=(self.nf, self.nloc, nap)) + else: + iap = None + ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) + ret1 = ifn1(dd[0], atype, fparam=ifp, aparam=iap) + np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + + def test_self_exception( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + distinguish_types, + od, + nfp, + nap, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + distinguish_types=distinguish_types, + ) + + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp)) + else: + ifp = None + if nap > 0: + iap = rng.normal(size=(self.nf, self.nloc, nap)) + else: + iap = None + with self.assertRaises(ValueError) as context: + ret0 = ifn0(dd[0][:, :, :-2], atype, fparam=ifp, aparam=iap) + self.assertIn("input descriptor", context.exception) + + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp - 1)) + with self.assertRaises(ValueError) as context: + ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) + self.assertIn("input fparam", context.exception) + + if nap > 0: + iap = rng.normal(size=(self.nf, self.nloc, nap - 1)) + with self.assertRaises(ValueError) as context: + ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) + self.assertIn("input aparam", context.exception) + + def test_get_set(self): + ifn0 = InvarFitting( + "energy", + self.nt, + 3, + 1, + ) + rng = np.random.default_rng() + foo = rng.normal([3, 4]) + for ii in [ + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ]: + ifn0[ii] = foo + np.testing.assert_allclose(foo, ifn0[ii]) diff --git a/source/tests/pt/test_ener_fitting.py b/source/tests/pt/test_ener_fitting.py new file mode 100644 index 0000000000..eece8447df --- /dev/null +++ b/source/tests/pt/test_ener_fitting.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.model_format import InvarFitting as DPInvarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) + + for od, distinguish_types, nfp, nap in itertools.product( + [1, 3], + [True, False], + [0, 3], + [0, 4], + ): + ft0 = InvarFitting( + "foo", + self.nt, + dd0.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + ft1 = DPInvarFitting.deserialize(ft0.serialize()) + ft2 = 
InvarFitting.deserialize(ft0.serialize()) + + if nfp > 0: + ifp = torch.tensor( + rng.normal(size=(self.nf, nfp)), dtype=dtype, device=env.DEVICE + ) + else: + ifp = None + if nap > 0: + iap = torch.tensor( + rng.normal(size=(self.nf, self.nloc, nap)), + dtype=dtype, + device=env.DEVICE, + ) + else: + iap = None + + ret0 = ft0(rd0, atype, fparam=ifp, aparam=iap) + ret1 = ft1( + rd0.detach().cpu().numpy(), + atype.detach().cpu().numpy(), + fparam=to_numpy_array(ifp), + aparam=to_numpy_array(iap), + ) + ret2 = ft2(rd0, atype, fparam=ifp, aparam=iap) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + ret1["foo"], + ) + np.testing.assert_allclose( + to_numpy_array(ret0["foo"]), + to_numpy_array(ret2["foo"]), + ) + + def test_new_old( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + dd = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) + rd0, _, _, _, _ = dd( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) + + od = 1 + for distinguish_types in itertools.product( + [True, False], + ): + ft0 = EnergyFittingNet( + self.nt, + dd.dim_out, + distinguish_types=distinguish_types, + ).to(env.DEVICE) + ft1 = EnergyFittingNet( + self.nt, + dd.dim_out, + distinguish_types=distinguish_types, + old_impl=True, + ).to(env.DEVICE) + dd0 = ft0.state_dict() + dd1 = ft1.state_dict() + for kk, vv in dd1.items(): + new_kk = kk + new_kk = new_kk.replace("filter_layers_old", "filter_layers.networks") + new_kk = new_kk.replace("deep_layers", "layers") + new_kk = new_kk.replace("final_layer", "layers.3") + dd1[kk] = dd0[new_kk] + if kk.split(".")[-1] in ["idt", "bias"]: + dd1[kk] = dd1[kk].unsqueeze(0) + dd1["bias_atom_e"] = dd0["bias_atom_e"] + ft1.load_state_dict(dd1) + ret0 = ft0(rd0, atype) + ret1 = ft1(rd0, atype) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + + def test_jit( + self, + ): + for od, distinguish_types, nfp, nap in itertools.product( + [1, 3], + [True, False], + [0, 3], + [0, 4], + ): + ft0 = InvarFitting( + "foo", + self.nt, + 9, + od, + numb_fparam=nfp, + numb_aparam=nap, + use_tebd=(not distinguish_types), + ).to(env.DEVICE) + torch.jit.script(ft0) + + def test_get_set(self): + ifn0 = InvarFitting( + "energy", + self.nt, + 3, + 1, + ) + rng = np.random.default_rng() + foo = rng.normal([3, 4]) + for ii in [ + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ]: + ifn0[ii] = torch.tensor(foo, dtype=dtype, device=env.DEVICE) + np.testing.assert_allclose(foo, ifn0[ii].detach().cpu().numpy()) diff --git a/source/tests/pt/test_fitting_net.py b/source/tests/pt/test_fitting_net.py index 3feb4f4739..ed2c428de5 100644 --- a/source/tests/pt/test_fitting_net.py +++ b/source/tests/pt/test_fitting_net.py @@ -102,25 +102,25 @@ def test_consistency(self): my_fn = EnergyFittingNet( self.ntypes, self.embedding_width, - self.n_neuron, - self.dp_fn.bias_atom_e, - use_tebd=False, + neuron=self.n_neuron, + bias_atom_e=self.dp_fn.bias_atom_e, + distinguish_types=True, ) for name, param in my_fn.named_parameters(): - matched = re.match("filter_layers\.(\d).deep_layers\.(\d)\.([a-z]+)", name) + matched = re.match( + "filter_layers\.networks\.(\d).layers\.(\d)\.([a-z]+)", name + ) key = None if matched: + if int(matched.group(2)) == len(self.n_neuron): + 
layer_id = -1 + else: + layer_id = matched.group(2) key = gen_key( type_id=matched.group(1), - layer_id=matched.group(2), + layer_id=layer_id, w_or_b=matched.group(3), ) - else: - matched = re.match("filter_layers\.(\d).final_layer\.([a-z]+)", name) - if matched: - key = gen_key( - type_id=matched.group(1), layer_id=-1, w_or_b=matched.group(2) - ) assert key is not None var = values[key] with torch.no_grad(): @@ -132,7 +132,7 @@ def test_consistency(self): ret = my_fn(embedding, atype) my_energy = ret["energy"] my_energy = my_energy.detach() - self.assertTrue(np.allclose(dp_energy, my_energy.numpy().reshape([-1]))) + np.testing.assert_allclose(dp_energy, my_energy.numpy().reshape([-1])) if __name__ == "__main__": diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py index 5bbbc9e352..c6595e6471 100644 --- a/source/tests/pt/test_model.py +++ b/source/tests/pt/test_model.py @@ -53,23 +53,24 @@ VariableState = collections.namedtuple("VariableState", ["value", "gradient"]) -def torch2tf(torch_name): +def torch2tf(torch_name, last_layer_id=None): fields = torch_name.split(".") offset = int(fields[2] == "networks") element_id = int(fields[2 + offset]) if fields[0] == "descriptor": layer_id = int(fields[4 + offset]) + 1 weight_type = fields[5 + offset] - return "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) - elif fields[3] == "deep_layers": - layer_id = int(fields[4]) - weight_type = fields[5] - return "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type) - elif fields[3] == "final_layer": - weight_type = fields[4] - return "final_layer_type_%d/%s:0" % (element_id, weight_type) + ret = "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) + elif fields[0] == "fitting_net": + layer_id = int(fields[4 + offset]) + weight_type = fields[5 + offset] + if layer_id != last_layer_id: + ret = "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type) + else: + ret = "final_layer_type_%d/%s:0" % (element_id, weight_type) else: raise RuntimeError("Unexpected parameter name: %s" % torch_name) + return ret class DpTrainer: @@ -290,7 +291,7 @@ def test_consistency(self): "neuron": self.filter_neuron, "axis_neuron": self.axis_neuron, }, - "fitting_net": {"neuron": self.n_neuron}, + "fitting_net": {"neuron": self.n_neuron, "distinguish_types": True}, "data_stat_nbatch": self.data_stat_nbatch, "type_map": self.type_map, }, @@ -323,7 +324,7 @@ def test_consistency(self): # Keep parameter value consistency between 2 implentations for name, param in my_model.named_parameters(): name = name.replace("sea.", "") - var_name = torch2tf(name) + var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) var = vs_dict[var_name].value with torch.no_grad(): src = torch.from_numpy(var) @@ -404,7 +405,7 @@ def step(step_id): for name, param in my_model.named_parameters(): name = name.replace("sea.", "") - var_name = torch2tf(name) + var_name = torch2tf(name, last_layer_id=len(self.n_neuron)) var_grad = vs_dict[var_name].gradient param_grad = param.grad.cpu() var_grad = torch.tensor(var_grad) diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py index c0a106cb16..0da80ea1ea 100644 --- a/source/tests/pt/test_se_e2_a.py +++ b/source/tests/pt/test_se_e2_a.py @@ -25,6 +25,9 @@ PRECISION_DICT, ) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) from .test_mlp import ( get_tols, ) @@ -32,36 +35,6 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -class TestCaseSingleFrameWithNlist: - def setUp(self): - # nloc == 3, nall == 4 - self.nloc = 3 
- self.nall = 4 - self.nf, self.nt = 1, 2 - self.coord_ext = np.array( - [ - [0, 0, 0], - [0, 1, 0], - [0, 0, 1], - [0, -2, 0], - ], - dtype=np.float64, - ).reshape([1, self.nall * 3]) - self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) - # sel = [5, 2] - self.sel = [5, 2] - self.nlist = np.array( - [ - [1, 3, -1, -1, -1, 2, -1], - [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], - ], - dtype=int, - ).reshape([1, self.nloc, sum(self.sel)]) - self.rcut = 0.4 - self.rcut_smth = 2.2 - - # to be merged with the tf test case @unittest.skipIf(not support_se_e2_a, "EnvMat not supported") class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): diff --git a/source/tests/pt/test_utils.py b/source/tests/pt/test_utils.py new file mode 100644 index 0000000000..9c9a9479ad --- /dev/null +++ b/source/tests/pt/test_utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + + +class TestCvt(unittest.TestCase): + def test_to_numpy(self): + rng = np.random.default_rng() + foo = rng.normal([3, 4]) + for ptp, npp in zip( + [torch.float16, torch.float32, torch.float64], + [np.float16, np.float32, np.float64], + ): + foo = foo.astype(npp) + bar = to_torch_tensor(foo) + self.assertEqual(bar.dtype, ptp) + onk = to_numpy_array(bar) + self.assertEqual(onk.dtype, npp) + with self.assertRaises(ValueError) as ee: + foo = foo.astype(np.int32) + bar = to_torch_tensor(foo) + with self.assertRaises(ValueError) as ee: + bar = to_torch_tensor(foo) + bar = to_numpy_array(bar.int())
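
The new `to_numpy_array`/`to_torch_tensor` helpers in `deepmd/pt/utils/utils.py` translate dtypes through the two precision dictionaries, so a round trip preserves the floating-point precision and unmapped (e.g. integer) dtypes raise `ValueError`, as `test_utils.py` above checks. A small usage sketch under those assumptions:

    import numpy as np
    import torch

    from deepmd.pt.utils.utils import to_numpy_array, to_torch_tensor

    x_np = np.zeros([2, 3], dtype=np.float32)
    x_pt = to_torch_tensor(x_np)       # torch.float32 tensor on the configured device
    assert x_pt.dtype == torch.float32
    x_back = to_numpy_array(x_pt)      # back to numpy without losing precision
    assert x_back.dtype == np.float32

    try:
        to_torch_tensor(x_np.astype(np.int32))  # integer dtypes are not in the mapping
    except ValueError:
        pass
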