diff --git a/deepmd/backend/pytorch.py b/deepmd/backend/pytorch.py index 676694172b..fb7d30e994 100644 --- a/deepmd/backend/pytorch.py +++ b/deepmd/backend/pytorch.py @@ -29,8 +29,8 @@ @Backend.register("pt") @Backend.register("pytorch") -class TensorFlowBackend(Backend): - """TensorFlow backend.""" +class PyTorchBackend(Backend): + """PyTorch backend.""" name = "PyTorch" """The formal name of the backend.""" diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a068a2e366..531aa09f3a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -111,6 +111,8 @@ class DescrptSeA(NativeOP, BaseDescriptor): exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. set_davg_zero Set the shift of embedding net input to zero. activation_function @@ -149,6 +151,7 @@ def __init__( trainable: bool = True, type_one_side: bool = True, exclude_types: List[List[int]] = [], + env_protection: float = 0.0, set_davg_zero: bool = False, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, @@ -169,6 +172,7 @@ def __init__( self.resnet_dt = resnet_dt self.trainable = trainable self.type_one_side = type_one_side + self.env_protection = env_protection self.set_davg_zero = set_davg_zero self.activation_function = activation_function self.precision = precision @@ -192,7 +196,7 @@ def __init__( self.resnet_dt, self.precision, ) - self.env_mat = EnvMat(self.rcut, self.rcut_smth) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) self.nnei = np.sum(self.sel) self.davg = np.zeros( [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] @@ -378,6 +382,7 @@ def serialize(self) -> dict: "trainable": self.trainable, "type_one_side": self.type_one_side, "exclude_types": self.exclude_types, + "env_protection": self.env_protection, "set_davg_zero": self.set_davg_zero, "activation_function": self.activation_function, # make deterministic @@ -406,7 +411,6 @@ def deserialize(cls, data: dict) -> "DescrptSeA": obj["davg"] = variables["davg"] obj["dstd"] = variables["dstd"] obj.embeddings = NetworkCollection.deserialize(embeddings) - obj.env_mat = EnvMat.deserialize(env_mat) return obj @classmethod diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 2dbf495d14..3128a28493 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -106,6 +106,7 @@ def __init__( trainable: bool = True, type_one_side: bool = True, exclude_types: List[List[int]] = [], + env_protection: float = 0.0, set_davg_zero: bool = False, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, @@ -133,6 +134,7 @@ def __init__( self.precision = precision self.spin = spin self.emask = PairExcludeMask(self.ntypes, self.exclude_types) + self.env_protection = env_protection in_dim = 1 # not considiering type embedding self.embeddings = NetworkCollection( @@ -150,7 +152,7 @@ def __init__( self.resnet_dt, self.precision, ) - self.env_mat = EnvMat(self.rcut, self.rcut_smth) + self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) self.nnei = np.sum(self.sel) self.davg = np.zeros( [self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision] @@ -305,6 +307,7 @@ def serialize(self) -> dict: "trainable": self.trainable, "type_one_side": self.type_one_side, "exclude_types": self.exclude_types, + "env_protection": self.env_protection, "set_davg_zero": self.set_davg_zero, "activation_function": self.activation_function, # make deterministic @@ -333,7 +336,6 @@ def deserialize(cls, data: dict) -> "DescrptSeR": obj["davg"] = variables["davg"] obj["dstd"] = variables["dstd"] obj.embeddings = NetworkCollection.deserialize(embeddings) - obj.env_mat = EnvMat.deserialize(env_mat) return obj @classmethod diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index de41bebf6d..3a0e9909b9 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -63,6 +63,7 @@ def __init__( use_aparam_as_mask=use_aparam_as_mask, spin=spin, mixed_types=mixed_types, + exclude_types=exclude_types, ) @classmethod diff --git a/deepmd/dpmodel/model/__init__.py b/deepmd/dpmodel/model/__init__.py index dda174fa4e..cb796e6d35 100644 --- a/deepmd/dpmodel/model/__init__.py +++ b/deepmd/dpmodel/model/__init__.py @@ -16,8 +16,12 @@ from .make_model import ( make_model, ) +from .spin_model import ( + SpinModel, +) __all__ = [ "DPModel", + "SpinModel", "make_model", ] diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 6f06785c56..3fdf5b802b 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -8,10 +8,16 @@ from deepmd.dpmodel.model.dp_model import ( DPModel, ) +from deepmd.dpmodel.model.spin_model import ( + SpinModel, +) +from deepmd.utils.spin import ( + Spin, +) -def get_model(data: dict) -> DPModel: - """Get a DPModel from a dictionary. +def get_standard_model(data: dict) -> DPModel: + """Get a standard DPModel from a dictionary. Parameters ---------- @@ -30,6 +36,7 @@ def get_model(data: dict) -> DPModel: fitting = EnergyFittingNet( ntypes=descriptor.get_ntypes(), dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), **data["fitting_net"], ) else: @@ -41,3 +48,50 @@ def get_model(data: dict) -> DPModel: atom_exclude_types=data.get("atom_exclude_types", []), pair_exclude_types=data.get("pair_exclude_types", []), ) + + +def get_spin_model(data: dict) -> SpinModel: + """Get a spin model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + # include virtual spin and placeholder types + data["type_map"] += [item + "_spin" for item in data["type_map"]] + spin = Spin( + use_spin=data["spin"]["use_spin"], + virtual_scale=data["spin"]["virtual_scale"], + ) + pair_exclude_types = spin.get_pair_exclude_types( + exclude_types=data.get("pair_exclude_types", None) + ) + data["pair_exclude_types"] = pair_exclude_types + # for descriptor data stat + data["descriptor"]["exclude_types"] = pair_exclude_types + atom_exclude_types = spin.get_atom_exclude_types( + exclude_types=data.get("atom_exclude_types", None) + ) + data["atom_exclude_types"] = atom_exclude_types + if "env_protection" not in data["descriptor"]: + data["descriptor"]["env_protection"] = 1e-6 + if data["descriptor"]["type"] in ["se_e2_a"]: + # only expand sel for se_e2_a + data["descriptor"]["sel"] += data["descriptor"]["sel"] + backbone_model = get_standard_model(data) + return SpinModel(backbone_model=backbone_model, spin=spin) + + +def get_model(data: dict): + """Get a model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + if "spin" in data: + return get_spin_model(data) + else: + return get_standard_model(data) diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py new file mode 100644 index 0000000000..5b31b64fdf --- /dev/null +++ b/deepmd/dpmodel/model/spin_model.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, +) + +import numpy as np + +from deepmd.dpmodel.model.dp_model import ( + DPModel, +) +from deepmd.utils.spin import ( + Spin, +) + + +class SpinModel: + """A spin model wrapper, with spin input preprocess and output split.""" + + def __init__( + self, + backbone_model, + spin: Spin, + ): + super().__init__() + self.backbone_model = backbone_model + self.spin = spin + self.ntypes_real = self.spin.ntypes_real + self.virtual_scale_mask = self.spin.get_virtual_scale_mask() + self.spin_mask = self.spin.get_spin_mask() + + def process_spin_input(self, coord, atype, spin): + """Generate virtual coordinates and types, concat into the input.""" + nframes, nloc = coord.shape[:-1] + atype_spin = np.concatenate([atype, atype + self.ntypes_real], axis=-1) + virtual_coord = coord + spin * self.virtual_scale_mask[atype].reshape( + [nframes, nloc, 1] + ) + coord_spin = np.concatenate([coord, virtual_coord], axis=-2) + return coord_spin, atype_spin + + def process_spin_input_lower( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + extended_spin: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + ): + """ + Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`. + Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. + """ + nframes, nall = extended_coord.shape[:2] + nloc = nlist.shape[1] + virtual_extended_coord = ( + extended_coord + + extended_spin + * self.virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + ) + virtual_extended_atype = extended_atype + self.ntypes_real + extended_coord_updated = self.concat_switch_virtual( + extended_coord, virtual_extended_coord, nloc + ) + extended_atype_updated = self.concat_switch_virtual( + extended_atype, virtual_extended_atype, nloc + ) + if mapping is not None: + virtual_mapping = mapping + nloc + mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc) + else: + mapping_updated = None + # extend the nlist + nlist_updated = self.extend_nlist(extended_atype, nlist) + return ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) + + def process_spin_output( + self, atype, out_tensor, add_mag: bool = True, virtual_scale: bool = True + ): + """Split the output both real and virtual atoms, and scale the latter.""" + nframes, nloc_double = out_tensor.shape[:2] + nloc = nloc_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[atype].reshape([nframes, nloc, 1]) + out_real, out_mag = np.split(out_tensor, [nloc], axis=1) + if add_mag: + out_real = out_real + out_mag + out_mag = (out_mag.reshape([nframes, nloc, -1]) * atomic_mask).reshape( + out_mag.shape + ) + return out_real, out_mag, atomic_mask > 0.0 + + def process_spin_output_lower( + self, + extended_atype, + extended_out_tensor, + nloc: int, + add_mag: bool = True, + virtual_scale: bool = True, + ): + """Split the extended output of both real and virtual atoms with switch, and scale the latter.""" + nframes, nall_double = extended_out_tensor.shape[:2] + nall = nall_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + extended_out_real = np.concatenate( + [ + extended_out_tensor[:, :nloc], + extended_out_tensor[:, nloc + nloc : nloc + nall], + ], + axis=1, + ) + extended_out_mag = np.concatenate( + [ + extended_out_tensor[:, nloc : nloc + nloc], + extended_out_tensor[:, nloc + nall :], + ], + axis=1, + ) + if add_mag: + extended_out_real = extended_out_real + extended_out_mag + extended_out_mag = ( + extended_out_mag.reshape([nframes, nall, -1]) * atomic_mask + ).reshape(extended_out_mag.shape) + return extended_out_real, extended_out_mag, atomic_mask > 0.0 + + @staticmethod + def extend_nlist(extended_atype, nlist): + nframes, nloc, nnei = nlist.shape + nall = extended_atype.shape[1] + nlist_mask = nlist != -1 + nlist[nlist == -1] = 0 + nlist_shift = nlist + nall + nlist[~nlist_mask] = -1 + nlist_shift[~nlist_mask] = -1 + self_spin = np.arange(0, nloc, dtype=nlist.dtype) + nall + self_spin = self_spin.reshape(1, -1, 1).repeat(nframes, axis=0) + # self spin + real neighbor + virtual neighbor + # nf x nloc x (1 + nnei + nnei) + extended_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1) + # nf x (nloc + nloc) x (1 + nnei + nnei) + extended_nlist = np.concatenate( + [extended_nlist, -1 * np.ones_like(extended_nlist)], axis=-2 + ) + # update the index for switch + first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall) + second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc)) + extended_nlist[first_part_index] += nloc + extended_nlist[second_part_index] -= nall - nloc + return extended_nlist + + @staticmethod + def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int): + nframes, nall = extended_tensor.shape[:2] + out_shape = list(extended_tensor.shape) + out_shape[1] *= 2 + extended_tensor_updated = np.zeros( + out_shape, + dtype=extended_tensor.dtype, + ) + extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] + extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[ + :, :nloc + ] + extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[ + :, nloc: + ] + extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] + return extended_tensor_updated.reshape(out_shape) + + def get_type_map(self) -> List[str]: + """Get the type map.""" + tmap = self.backbone_model.get_type_map() + ntypes = len(tmap) // 2 # ignore the virtual type + return tmap[:ntypes] + + def get_rcut(self): + """Get the cut-off radius.""" + return self.backbone_model.get_rcut() + + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.backbone_model.get_dim_fparam() + + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.backbone_model.get_dim_aparam() + + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model. + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.backbone_model.get_sel_type() + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + If False, the shape is (nframes, nloc, ndim). + """ + return self.backbone_model.is_aparam_nall() + + def model_output_type(self) -> List[str]: + """Get the output type for the model.""" + return self.backbone_model.model_output_type() + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.backbone_model.get_model_def_script() + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + # for C++ interface + if not self.backbone_model.mixed_types(): + return self.backbone_model.get_nnei() // 2 # ignore the virtual selected + else: + return self.backbone_model.get_nnei() + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + if not self.backbone_model.mixed_types(): + return self.backbone_model.get_nsel() // 2 # ignore the virtual selected + else: + return self.backbone_model.get_nsel() + + @staticmethod + def has_spin() -> bool: + """Returns whether it has spin input and output.""" + return True + + def __getattr__(self, name): + """Get attribute from the wrapped model.""" + if name in self.__dict__: + return self.__dict__[name] + else: + return getattr(self.backbone_model, name) + + def serialize(self) -> dict: + return { + "backbone_model": self.backbone_model.serialize(), + "spin": self.spin.serialize(), + } + + @classmethod + def deserialize(cls, data) -> "SpinModel": + backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + spin = Spin.deserialize(data["spin"]) + return cls( + backbone_model=backbone_model_obj, + spin=spin, + ) + + def call( + self, + coord, + atype, + spin, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, np.ndarray]: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + spin + The spins of the atoms. + shape: nf x (nloc x 3) + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type Dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + nframes, nloc = coord.shape[:2] + coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) + model_predict = self.backbone_model.call( + coord_updated, + atype_updated, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_predict[f"{var_name}"] = np.split( + model_predict[f"{var_name}"], [nloc], axis=1 + )[0] + # for now omit the grad output + return model_predict + + def call_lower( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + extended_spin: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. Lower interface that takes + extended atomic coordinates, types and spins, nlist, and mapping + as input, and returns the predictions on the extended region. + The predictions are not reduced. + + Parameters + ---------- + extended_coord + coordinates in extended region. nf x (nall x 3). + extended_atype + atomic type in extended region. nf x nall. + extended_spin + spins in extended region. nf x (nall x 3). + nlist + neighbor list. nf x nloc x nsel. + mapping + maps the extended indices to local indices. nf x nall. + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + whether calculate atomic virial + + Returns + ------- + result_dict + the result dict, defined by the `FittingOutputDef`. + + """ + nframes, nloc = nlist.shape[:2] + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) = self.process_spin_input_lower( + extended_coord, extended_atype, extended_spin, nlist, mapping=mapping + ) + model_predict = self.backbone_model.call_lower( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping=mapping_updated, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_predict[f"{var_name}"] = np.split( + model_predict[f"{var_name}"], [nloc], axis=1 + )[0] + # for now omit the grad output + return model_predict diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index ac41513246..cbebb4908a 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -125,6 +125,8 @@ class OutputVariableOperation(IntEnum): """Derivative w.r.t. cell.""" _SEC_DERV_R = 8 """Second derivative w.r.t. coordinates.""" + MAG = 16 + """Magnetic output.""" class OutputVariableCategory(IntEnum): @@ -142,6 +144,10 @@ class OutputVariableCategory(IntEnum): """Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023). """ DERV_R_DERV_R = OutputVariableOperation.DERV_R | OutputVariableOperation._SEC_DERV_R """Hession matrix, the second derivative w.r.t. coordinates.""" + DERV_R_MAG = OutputVariableOperation.DERV_R | OutputVariableOperation.MAG + """Magnetic part of negative derivative w.r.t. coordinates. (e.g. magnetic force)""" + DERV_C_MAG = OutputVariableOperation.DERV_C | OutputVariableOperation.MAG + """Magnetic part of atomic component of the virial.""" class OutputVariableDef: @@ -176,8 +182,10 @@ class OutputVariableDef: If the variable is defined for each atom. category : int The category of the output variable. - hessian : bool + r_hessian : bool If hessian is requred + magnetic : bool + If the derivatives of variable have magnetic parts. """ def __init__( @@ -190,6 +198,7 @@ def __init__( atomic: bool = True, category: int = OutputVariableCategory.OUT.value, r_hessian: bool = False, + magnetic: bool = False, ): self.name = name self.shape = list(shape) @@ -208,6 +217,7 @@ def __init__( raise ValueError("a reduciable variable should be atomic") self.category = category self.r_hessian = r_hessian + self.magnetic = magnetic if self.r_hessian: if not self.reduciable: raise ValueError("only reduciable variable can calculate hessian") @@ -271,6 +281,7 @@ def __init__( self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data()) self.def_hess_r, _ = do_derivative(self.def_derv_r) self.def_derv_c_redu = do_reduce(self.def_derv_c) + self.def_mask = do_mask(self.def_outp.get_data()) self.var_defs: Dict[str, OutputVariableDef] = {} for ii in [ self.def_outp.get_data(), @@ -279,6 +290,7 @@ def __init__( self.def_derv_r, self.def_derv_c_redu, self.def_hess_r, + self.def_mask, ]: self.var_defs.update(ii) @@ -324,12 +336,16 @@ def get_deriv_name(name: str) -> Tuple[str, str]: return name + "_derv_r", name + "_derv_c" +def get_deriv_name_mag(name: str) -> Tuple[str, str]: + return name + "_derv_r_mag", name + "_derv_c_mag" + + def get_hessian_name(name: str) -> str: return name + "_derv_r_derv_r" def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int: - """Apply a operation to the category of a variable definition. + """Apply an operation to the category of a variable definition. Parameters ---------- @@ -401,6 +417,31 @@ def do_reduce( return def_redu +def do_mask( + def_outp_data: Dict[str, OutputVariableDef], +) -> Dict[str, OutputVariableDef]: + def_mask: Dict[str, OutputVariableDef] = {} + # for deep eval when has atomic mask + def_mask["mask"] = OutputVariableDef( + name="mask", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + for kk, vv in def_outp_data.items(): + if vv.magnetic: + # for deep eval when has atomic mask for magnetic atoms + def_mask["mask_mag"] = OutputVariableDef( + name="mask_mag", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + return def_mask + + def do_derivative( def_outp_data: Dict[str, OutputVariableDef], ) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]: @@ -408,6 +449,7 @@ def do_derivative( def_derv_c: Dict[str, OutputVariableDef] = {} for kk, vv in def_outp_data.items(): rkr, rkc = get_deriv_name(kk) + rkrm, rkcm = get_deriv_name_mag(kk) if vv.r_differentiable: def_derv_r[rkr] = OutputVariableDef( rkr, @@ -420,9 +462,22 @@ def do_derivative( atomic=True, category=apply_operation(vv, OutputVariableOperation.DERV_R), ) + if vv.magnetic: + def_derv_r[rkrm] = OutputVariableDef( + rkrm, + vv.shape + [3], # noqa: RUF005 + reduciable=False, + r_differentiable=( + vv.r_hessian and vv.category == OutputVariableCategory.OUT.value + ), + c_differentiable=False, + atomic=True, + category=apply_operation(vv, OutputVariableOperation.DERV_R), + magnetic=True, + ) + if vv.c_differentiable: assert vv.r_differentiable - rkr, rkc = get_deriv_name(kk) def_derv_c[rkc] = OutputVariableDef( rkc, vv.shape + [9], # noqa: RUF005 @@ -432,4 +487,15 @@ def do_derivative( atomic=True, category=apply_operation(vv, OutputVariableOperation.DERV_C), ) + if vv.magnetic: + def_derv_r[rkcm] = OutputVariableDef( + rkcm, + vv.shape + [9], # noqa: RUF005 + reduciable=True, + r_differentiable=False, + c_differentiable=False, + atomic=True, + category=apply_operation(vv, OutputVariableOperation.DERV_C), + magnetic=True, + ) return def_derv_r, def_derv_c diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 5fb4ac4107..0c2ca43c40 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -33,6 +33,7 @@ def _make_env_mat( rcut: float, ruct_smth: float, radial_only: bool = False, + protection: float = 0.0, ): """Make smooth environment matrix.""" nf, nloc, nnei = nlist.shape @@ -53,8 +54,8 @@ def _make_env_mat( length = np.linalg.norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom length = length + ~np.expand_dims(mask, -1) - t0 = 1 / length - t1 = diff / length**2 + t0 = 1 / (length + protection) + t1 = diff / (length + protection) ** 2 weight = compute_smooth_weight(length, ruct_smth, rcut) weight = weight * np.expand_dims(mask, -1) if radial_only: @@ -69,9 +70,11 @@ def __init__( self, rcut, rcut_smth, + protection: float = 0.0, ): self.rcut = rcut self.rcut_smth = rcut_smth + self.protection = protection def call( self, @@ -120,7 +123,12 @@ def call( def _call(self, nlist, coord_ext, radial_only): em, diff, ww = _make_env_mat( - nlist, coord_ext, self.rcut, self.rcut_smth, radial_only + nlist, + coord_ext, + self.rcut, + self.rcut_smth, + radial_only=radial_only, + protection=self.protection, ) return em, ww diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 360f190e13..ff668b8153 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -24,6 +24,12 @@ def __init__( # (ntypes) self.type_mask = self.type_mask.reshape([-1]) + def get_exclude_types(self): + return self.exclude_types + + def get_type_mask(self): + return self.type_mask + def build_type_exclude_mask( self, atype: np.ndarray, @@ -75,6 +81,9 @@ def __init__( # (ntypes+1 x ntypes+1) self.type_mask = self.type_mask.reshape([-1]) + def get_exclude_types(self): + return self.exclude_types + def build_type_exclude_mask( self, nlist: np.ndarray, diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 657d6ecee2..1aa1820495 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -69,6 +69,8 @@ def build_neighbor_list( ) assert list(diff.shape) == [batch_size, nloc, nall, 3] rr = np.linalg.norm(diff, axis=-1) + # if central atom has two zero distances, sorting sometimes can not exclude itself + rr -= np.eye(nloc, nall, dtype=diff.dtype)[np.newaxis, :, :] nlist = np.argsort(rr, axis=-1) rr = np.sort(rr, axis=-1) rr = rr[:, :, 1:] diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index efc75e31a7..ccf8b1da1e 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -298,6 +298,9 @@ def test_ener( ) if dp.get_dim_aparam() > 0: data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False) + if dp.has_spin: + data.add("spin", 3, atomic=True, must=True, high_prec=False) + data.add("force_mag", 3, atomic=True, must=False, high_prec=False) test_data = data.get_test() mixed_type = data.mixed_type @@ -311,6 +314,10 @@ def test_ener( efield = test_data["efield"][:numb_test].reshape([numb_test, -1]) else: efield = None + if dp.has_spin: + spin = test_data["spin"][:numb_test].reshape([numb_test, -1]) + else: + spin = None if not data.pbc: box = None if mixed_type: @@ -335,6 +342,7 @@ def test_ener( atomic=has_atom_ener, efield=efield, mixed_type=mixed_type, + spin=spin, ) energy = ret[0] force = ret[1] @@ -347,26 +355,50 @@ def test_ener( av = ret[4] ae = ae.reshape([numb_test, -1]) av = av.reshape([numb_test, -1]) - if dp.get_ntypes_spin() != 0: - ntypes_real = dp.get_ntypes() - dp.get_ntypes_spin() - nloc = natoms - nloc_real = sum([np.count_nonzero(atype == ii) for ii in range(ntypes_real)]) - force_r = np.split( - force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 - )[0] - force_m = np.split( - force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 - )[1] - test_force_r = np.split( - test_data["force"][:numb_test], - indices_or_sections=[nloc_real * 3, nloc * 3], - axis=1, - )[0] - test_force_m = np.split( - test_data["force"][:numb_test], - indices_or_sections=[nloc_real * 3, nloc * 3], - axis=1, - )[1] + if dp.has_spin: + force_m = ret[5] + force_m = force_m.reshape([numb_test, -1]) + mask_mag = ret[6] + mask_mag = mask_mag.reshape([numb_test, -1]) + else: + if dp.has_spin: + force_m = ret[3] + force_m = force_m.reshape([numb_test, -1]) + mask_mag = ret[4] + mask_mag = mask_mag.reshape([numb_test, -1]) + out_put_spin = dp.get_ntypes_spin() != 0 or dp.has_spin + if out_put_spin: + if dp.get_ntypes_spin() != 0: # old tf support for spin + ntypes_real = dp.get_ntypes() - dp.get_ntypes_spin() + nloc = natoms + nloc_real = sum( + [np.count_nonzero(atype == ii) for ii in range(ntypes_real)] + ) + force_r = np.split( + force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 + )[0] + force_m = np.split( + force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 + )[1] + test_force_r = np.split( + test_data["force"][:numb_test], + indices_or_sections=[nloc_real * 3, nloc * 3], + axis=1, + )[0] + test_force_m = np.split( + test_data["force"][:numb_test], + indices_or_sections=[nloc_real * 3, nloc * 3], + axis=1, + )[1] + else: # pt support for spin + force_r = force + test_force_r = test_data["force"][:numb_test] + # The shape of force_m and test_force_m are [-1, 3], + # which is designed for mixed_type cases + force_m = force_m.reshape(-1, 3)[mask_mag.reshape(-1)] + test_force_m = test_data["force_mag"][:numb_test].reshape(-1, 3)[ + mask_mag.reshape(-1) + ] diff_e = energy - test_data["energy"][:numb_test].reshape([-1, 1]) mae_e = mae(diff_e) @@ -385,7 +417,7 @@ def test_ener( diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1]) mae_ae = mae(diff_ae) rmse_ae = rmse(diff_ae) - if dp.get_ntypes_spin() != 0: + if out_put_spin: mae_fr = mae(force_r - test_force_r) mae_fm = mae(force_m - test_force_m) rmse_fr = rmse(force_r - test_force_r) @@ -396,16 +428,16 @@ def test_ener( log.info(f"Energy RMSE : {rmse_e:e} eV") log.info(f"Energy MAE/Natoms : {mae_ea:e} eV") log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV") - if dp.get_ntypes_spin() == 0: + if not out_put_spin: log.info(f"Force MAE : {mae_f:e} eV/A") log.info(f"Force RMSE : {rmse_f:e} eV/A") else: log.info(f"Force atom MAE : {mae_fr:e} eV/A") - log.info(f"Force spin MAE : {mae_fm:e} eV/uB") log.info(f"Force atom RMSE : {rmse_fr:e} eV/A") + log.info(f"Force spin MAE : {mae_fm:e} eV/uB") log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB") - if data.pbc: + if data.pbc and not out_put_spin: log.info(f"Virial MAE : {mae_v:e} eV") log.info(f"Virial RMSE : {rmse_v:e} eV") log.info(f"Virial MAE/Natoms : {mae_va:e} eV") @@ -437,7 +469,7 @@ def test_ener( header="%s: data_e pred_e" % system, append=append_detail, ) - if dp.get_ntypes_spin() == 0: + if not out_put_spin: pf = np.concatenate( ( np.reshape(test_data["force"][:numb_test], [-1, 3]), @@ -497,7 +529,7 @@ def test_ener( "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", append=append_detail, ) - if dp.get_ntypes_spin() == 0: + if not out_put_spin: return { "mae_e": (mae_e, energy.size), "mae_ea": (mae_ea, energy.size), diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index de964b88b9..065982a870 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -57,7 +57,9 @@ class DeepEvalBackend(ABC): "energy": "atom_energy", "energy_redu": "energy", "energy_derv_r": "force", + "energy_derv_r_mag": "force_mag", "energy_derv_c": "atom_virial", + "energy_derv_c_mag": "atom_virial_mag", "energy_derv_c_redu": "virial", "polar": "polar", "polar_redu": "global_polar", @@ -71,6 +73,8 @@ class DeepEvalBackend(ABC): "dipole_derv_c_redu": "virial", "dos": "atom_dos", "dos_redu": "dos", + "mask_mag": "mask_mag", + "mask": "mask", } @abstractmethod @@ -262,9 +266,13 @@ def get_has_efield(self): """Check if the model has efield.""" return False + def get_has_spin(self): + """Check if the model has spin atom types.""" + return False + @abstractmethod def get_ntypes_spin(self) -> int: - """Get the number of spin atom types of this model.""" + """Get the number of spin atom types of this model. Only used in old implement.""" class DeepEval(ABC): @@ -317,6 +325,8 @@ def __init__( neighbor_list=neighbor_list, **kwargs, ) + if self.deep_eval.get_has_spin() and hasattr(self, "output_def_mag"): + self.deep_eval.output_def = self.output_def_mag @property @abstractmethod @@ -518,6 +528,11 @@ def has_efield(self) -> bool: """Check if the model has efield.""" return self.deep_eval.get_has_efield() + @property + def has_spin(self) -> bool: + """Check if the model has spin.""" + return self.deep_eval.get_has_spin() + def get_ntypes_spin(self) -> int: - """Get the number of spin atom types of this model.""" + """Get the number of spin atom types of this model. Only used in old implement.""" return self.deep_eval.get_ntypes_spin() diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index e955a3ed65..bc0bfc9599 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -70,6 +70,25 @@ def output_def(self) -> ModelOutputDef: ) ) + @property + def output_def_mag(self) -> ModelOutputDef: + """Get the output definition of this model with magnetic parts.""" + return ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + "energy", + shape=[1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + atomic=True, + magnetic=True, + ), + ] + ) + ) + def eval( self, coords: np.ndarray, @@ -162,7 +181,7 @@ def eval( natoms_real = natoms atomic_energy = results["energy"].reshape(nframes, natoms_real, 1) atomic_virial = results["energy_derv_c"].reshape(nframes, natoms, 9) - return ( + result = ( energy, force, virial, @@ -170,11 +189,16 @@ def eval( atomic_virial, ) else: - return ( + result = ( energy, force, virial, ) + if self.deep_eval.get_has_spin(): + force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3) + mask_mag = results["mask_mag"].reshape(nframes, natoms, 1) + result = (*list(result), force_mag, mask_mag) + return result __all__ = ["DeepPot"] diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index bf6a5b0306..b8031993c0 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -124,6 +124,9 @@ def __init__( self.auto_batch_size = auto_batch_size else: raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") + self._has_spin = getattr(self.dp.model["Default"], "has_spin", False) + if callable(self._has_spin): + self._has_spin = self._has_spin() def get_rcut(self) -> float: """Get the cutoff radius of this model.""" @@ -182,9 +185,13 @@ def get_has_efield(self): return False def get_ntypes_spin(self): - """Get the number of spin atom types of this model.""" + """Get the number of spin atom types of this model. Only used in old implement.""" return 0 + def get_has_spin(self): + """Check if the model has spin atom types.""" + return self._has_spin + def eval( self, coords: np.ndarray, @@ -240,14 +247,20 @@ def eval( coords, atom_types, len(atom_types.shape) > 1 ) request_defs = self._get_request_defs(atomic) - out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, - cells, - atom_types, - fparam, - aparam, - request_defs, - ) + if "spin" not in kwargs or kwargs["spin"] is None: + out = self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, fparam, aparam, request_defs + ) + else: + out = self._eval_func(self._eval_model_spin, numb_test, natoms)( + coords, + cells, + atom_types, + np.array(kwargs["spin"]), + fparam, + aparam, + request_defs, + ) return dict( zip( [x.name for x in request_defs], @@ -280,6 +293,7 @@ def _get_request_defs(self, atomic: bool) -> List[OutputVariableDef]: for x in self.output_def.var_defs.values() if x.category in ( + OutputVariableCategory.OUT, OutputVariableCategory.REDU, OutputVariableCategory.DERV_R, OutputVariableCategory.DERV_C_REDU, @@ -399,6 +413,82 @@ def _eval_model( results.append(np.full(np.abs(shape), np.nan)) # this is kinda hacky return tuple(results) + def _eval_model_spin( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + spins: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], + request_defs: List[OutputVariableDef], + ): + model = self.dp.to(DEVICE) + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + if cells is not None: + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + else: + box_input = None + if fparam is not None: + fparam_input = to_torch_tensor(fparam.reshape(-1, self.get_dim_fparam())) + else: + fparam_input = None + if aparam is not None: + aparam_input = to_torch_tensor( + aparam.reshape(-1, natoms, self.get_dim_aparam()) + ) + else: + aparam_input = None + + do_atomic_virial = any( + x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs + ) + batch_output = model( + coord_input, + type_input, + spin=spin_input, + box=box_input, + do_atomic_virial=do_atomic_virial, + fparam=fparam_input, + aparam=aparam_input, + ) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + + results = [] + for odef in request_defs: + pt_name = self._OUTDEF_DP2BACKEND[odef.name] + if pt_name in batch_output: + shape = self._get_output_shape(odef, nframes, natoms) + out = batch_output[pt_name].reshape(shape).detach().cpu().numpy() + results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append(np.full(np.abs(shape), np.nan)) # this is kinda hacky + return tuple(results) + def _get_output_shape(self, odef, nframes, natoms): if odef.category == OutputVariableCategory.DERV_C_REDU: # virial @@ -427,6 +517,7 @@ def eval_model( coords: Union[np.ndarray, torch.Tensor], cells: Optional[Union[np.ndarray, torch.Tensor]], atom_types: Union[np.ndarray, torch.Tensor, List[int]], + spins: Optional[Union[np.ndarray, torch.Tensor]] = None, atomic: bool = False, infer_batch_size: int = 2, denoise: bool = False, @@ -435,6 +526,7 @@ def eval_model( energy_out = [] atomic_energy_out = [] force_out = [] + force_mag_out = [] virial_out = [] atomic_virial_out = [] updated_coord_out = [] @@ -447,11 +539,15 @@ def eval_model( if isinstance(coords, torch.Tensor): if cells is not None: assert isinstance(cells, torch.Tensor), err_msg + if spins is not None: + assert isinstance(spins, torch.Tensor), err_msg assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) elif isinstance(coords, np.ndarray): if cells is not None: assert isinstance(cells, np.ndarray), err_msg + if spins is not None: + assert isinstance(spins, np.ndarray), err_msg assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) atom_types = np.array(atom_types, dtype=np.int32) return_tensor = False @@ -471,6 +567,16 @@ def eval_model( coord_input = torch.tensor( coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE ) + spin_input = None + if spins is not None: + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + has_spin = getattr(model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) box_input = None if cells is None: @@ -486,9 +592,20 @@ def eval_model( batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] batch_box = None + batch_spin = None + if spin_input is not None: + batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] if pbc: batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_output = model(batch_coord, batch_atype, box=batch_box) + input_dict = { + "coord": batch_coord, + "atype": batch_atype, + "box": batch_box, + "do_atomic_virial": atomic, + } + if has_spin: + input_dict["spin"] = batch_spin + batch_output = model(**input_dict) if isinstance(batch_output, tuple): batch_output = batch_output[0] if not return_tensor: @@ -500,6 +617,8 @@ def eval_model( ) if "force" in batch_output: force_out.append(batch_output["force"].detach().cpu().numpy()) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) if "virial" in batch_output: virial_out.append(batch_output["virial"].detach().cpu().numpy()) if "atom_virial" in batch_output: @@ -519,6 +638,8 @@ def eval_model( atomic_energy_out.append(batch_output["atom_energy"]) if "force" in batch_output: force_out.append(batch_output["force"]) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"]) if "virial" in batch_output: virial_out.append(batch_output["virial"]) if "atom_virial" in batch_output: @@ -539,6 +660,11 @@ def eval_model( force_out = ( np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) ) + force_mag_out = ( + np.concatenate(force_mag_out) + if force_mag_out + else np.zeros([nframes, natoms, 3]) + ) virial_out = ( np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) ) @@ -573,6 +699,13 @@ def eval_model( [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE ) ) + force_mag_out = ( + torch.cat(force_mag_out) + if force_mag_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) virial_out = ( torch.cat(virial_out) if virial_out @@ -592,13 +725,14 @@ def eval_model( if denoise: return updated_coord_out, logits_out else: - if not atomic: - return energy_out, force_out, virial_out - else: - return ( - energy_out, - force_out, - virial_out, - atomic_energy_out, - atomic_virial_out, - ) + results_dict = { + "energy": energy_out, + "force": force_out, + "virial": virial_out, + } + if has_spin: + results_dict["force_mag"] = force_mag_out + if atomic: + results_dict["atom_energy"] = atomic_energy_out + results_dict["atom_virial"] = atomic_virial_out + return results_dict diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index d2f6ab9e52..9c8bbc9a2a 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -5,6 +5,9 @@ from .ener import ( EnergyStdLoss, ) +from .ener_spin import ( + EnergySpinLoss, +) from .loss import ( TaskLoss, ) @@ -15,6 +18,7 @@ __all__ = [ "DenoiseLoss", "EnergyStdLoss", + "EnergySpinLoss", "TensorLoss", "TaskLoss", ] diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py new file mode 100644 index 0000000000..b94acf26ea --- /dev/null +++ b/deepmd/pt/loss/ener_spin.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, +) + +import torch +import torch.nn.functional as F + +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + GLOBAL_PT_FLOAT_PRECISION, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + + +class EnergySpinLoss(TaskLoss): + def __init__( + self, + starter_learning_rate=1.0, + start_pref_e=0.0, + limit_pref_e=0.0, + start_pref_fr=0.0, + limit_pref_fr=0.0, + start_pref_fm=0.0, + limit_pref_fm=0.0, + start_pref_v=0.0, + limit_pref_v=0.0, + start_pref_ae: float = 0.0, + limit_pref_ae: float = 0.0, + start_pref_pf: float = 0.0, + limit_pref_pf: float = 0.0, + use_l1_all: bool = False, + inference=False, + **kwargs, + ): + """Construct a layer to compute loss on energy, real force, magnetic force and virial.""" + super().__init__() + self.starter_learning_rate = starter_learning_rate + self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference + self.has_fr = (start_pref_fr != 0.0 and limit_pref_fr != 0.0) or inference + self.has_fm = (start_pref_fm != 0.0 and limit_pref_fm != 0.0) or inference + + # TODO need support for virial, atomic energy and atomic pref + self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference + self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference + self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference + + self.start_pref_e = start_pref_e + self.limit_pref_e = limit_pref_e + self.start_pref_fr = start_pref_fr + self.limit_pref_fr = limit_pref_fr + self.start_pref_fm = start_pref_fm + self.limit_pref_fm = limit_pref_fm + self.start_pref_v = start_pref_v + self.limit_pref_v = limit_pref_v + self.use_l1_all = use_l1_all + self.inference = inference + + def forward(self, model_pred, label, natoms, learning_rate, mae=False): + """Return energy loss with magnetic labels. + + Parameters + ---------- + model_pred : dict[str, torch.Tensor] + Model predictions. + label : dict[str, torch.Tensor] + Labels. + natoms : int + The local atom number. + + Returns + ------- + loss: torch.Tensor + Loss for model to minimize. + more_loss: dict[str, torch.Tensor] + Other losses for display. + """ + coef = learning_rate / self.starter_learning_rate + pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef + pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef + pref_fm = self.limit_pref_fm + (self.start_pref_fm - self.limit_pref_fm) * coef + pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef + loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) + more_loss = {} + # more_loss['log_keys'] = [] # showed when validation on the fly + # more_loss['test_keys'] = [] # showed when doing dp test + atom_norm = 1.0 / natoms + if self.has_e and "energy" in model_pred and "energy" in label: + if not self.use_l1_all: + l2_ener_loss = torch.mean( + torch.square(model_pred["energy"] - label["energy"]) + ) + if not self.inference: + more_loss["l2_ener_loss"] = l2_ener_loss.detach() + loss += atom_norm * (pref_e * l2_ener_loss) + rmse_e = l2_ener_loss.sqrt() * atom_norm + more_loss["rmse_e"] = rmse_e.detach() + # more_loss['log_keys'].append('rmse_e') + else: # use l1 and for all atoms + l1_ener_loss = F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="sum", + ) + loss += pref_e * l1_ener_loss + more_loss["mae_e"] = F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="mean", + ).detach() + # more_loss['log_keys'].append('rmse_e') + if mae: + mae_e = ( + torch.mean(torch.abs(model_pred["energy"] - label["energy"])) + * atom_norm + ) + more_loss["mae_e"] = mae_e.detach() + mae_e_all = torch.mean( + torch.abs(model_pred["energy"] - label["energy"]) + ) + more_loss["mae_e_all"] = mae_e_all.detach() + + if self.has_fr and "force" in model_pred and "force" in label: + if not self.use_l1_all: + diff_fr = label["force"] - model_pred["force"] + l2_force_real_loss = torch.mean(torch.square(diff_fr)) + if not self.inference: + more_loss["l2_force_r_loss"] = l2_force_real_loss.detach() + loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_fr = l2_force_real_loss.sqrt() + more_loss["rmse_fr"] = rmse_fr.detach() + if mae: + mae_fr = torch.mean(torch.abs(diff_fr)) + more_loss["mae_fr"] = mae_fr.detach() + else: + l1_force_real_loss = F.l1_loss( + label["force"], model_pred["force"], reduction="none" + ) + more_loss["mae_fr"] = l1_force_real_loss.mean().detach() + l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum() + loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION) + + if self.has_fm and "force_mag" in model_pred and "force_mag" in label: + nframes = model_pred["force_mag"].shape[0] + atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3]) + label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3) + model_pred_force_mag = model_pred["force_mag"][atomic_mask].view( + nframes, -1, 3 + ) + if not self.use_l1_all: + diff_fm = label_force_mag - model_pred_force_mag + l2_force_mag_loss = torch.mean(torch.square(diff_fm)) + if not self.inference: + more_loss["l2_force_m_loss"] = l2_force_mag_loss.detach() + loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_fm = l2_force_mag_loss.sqrt() + more_loss["rmse_fm"] = rmse_fm.detach() + if mae: + mae_fm = torch.mean(torch.abs(diff_fm)) + more_loss["mae_fm"] = mae_fm.detach() + else: + l1_force_mag_loss = F.l1_loss( + label_force_mag, model_pred_force_mag, reduction="none" + ) + more_loss["mae_fm"] = l1_force_mag_loss.mean().detach() + l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum() + loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION) + + if not self.inference: + more_loss["rmse"] = torch.sqrt(loss.detach()) + return loss, more_loss + + @property + def label_requirement(self) -> List[DataRequirementItem]: + """Return data label requirements needed for this loss calculation.""" + label_requirement = [] + if self.has_e: + label_requirement.append( + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ) + ) + if self.has_fr: + label_requirement.append( + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_fm: + label_requirement.append( + DataRequirementItem( + "force_mag", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_v: + label_requirement.append( + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ) + ) + if self.has_ae: + label_requirement.append( + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_pf: + label_requirement.append( + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ) + ) + return label_requirement diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 807f8433e5..cad1e1cc88 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import functools import logging from typing import ( Dict, @@ -204,9 +205,23 @@ def compute_or_load_stat( # descriptors and fitting net with different type_map # should not share the same parameters stat_file_path /= " ".join(self.type_map) - self.descriptor.compute_input_stats(sampled_func, stat_file_path) + + @functools.lru_cache + def wrapped_sampler(): + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) if self.fitting_net is not None: - self.fitting_net.compute_output_stats(sampled_func, stat_file_path) + self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path) @torch.jit.export def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 215bb25de5..19a67fc8ff 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -230,7 +230,7 @@ def compute_or_load_stat( """ bias_atom_e = compute_output_stats( - merged, stat_file_path, self.rcond, self.atom_ener + merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener ) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1]) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 1b32467540..21275317dc 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -3,6 +3,7 @@ Callable, List, Optional, + Tuple, Union, ) @@ -55,13 +56,14 @@ def __init__( temperature=None, return_rot=False, concat_output_tebd: bool = True, + env_protection: float = 0.0, type: Optional[str] = None, # not implemented resnet_dt: bool = False, type_one_side: bool = True, precision: str = "default", trainable: bool = True, - exclude_types: Optional[List[List[int]]] = None, + exclude_types: List[Tuple[int, int]] = [], stripped_type_embedding: bool = False, smooth_type_embdding: bool = False, ): @@ -72,8 +74,6 @@ def __init__( raise NotImplementedError("type_one_side is not supported.") if precision != "default" and precision != "float64": raise NotImplementedError("precison is not supported.") - if exclude_types is not None and exclude_types != []: - raise NotImplementedError("exclude_types is not supported.") if stripped_type_embedding: raise NotImplementedError("stripped_type_embedding is not supported.") if smooth_type_embdding: @@ -102,6 +102,8 @@ def __init__( normalize=normalize, temperature=temperature, return_rot=return_rot, + exclude_types=exclude_types, + env_protection=env_protection, ) self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) self.tebd_dim = tebd_dim diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index a80cc4a445..fb792a51e2 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -3,6 +3,7 @@ Callable, List, Optional, + Tuple, Union, ) @@ -77,7 +78,9 @@ def __init__( repformer_update_style: str = "res_avg", repformer_set_davg_zero: bool = True, # TODO repformer_add_type_ebd_to_seq: bool = False, + env_protection: float = 0.0, trainable: bool = True, + exclude_types: List[Tuple[int, int]] = [], type: Optional[ str ] = None, # work around the bad design in get_trainer and DpLoaderSet! @@ -175,6 +178,9 @@ def __init__( repformers block: concatenate the type embedding at the output. trainable : bool If the parameters in the descriptor are trainable. + exclude_types : List[Tuple[int, int]] = [], + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. Returns ------- @@ -205,6 +211,8 @@ def __init__( tebd_input_mode="concat", # tebd_input_mode='dot_residual_s', set_davg_zero=repinit_set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, activation_function=repinit_activation, ) self.repformers = DescrptBlockRepformers( @@ -236,6 +244,8 @@ def __init__( set_davg_zero=repformer_set_davg_zero, smooth=True, add_type_ebd_to_seq=repformer_add_type_ebd_to_seq, + exclude_types=exclude_types, + env_protection=env_protection, ) self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) if self.repinit.dim_out == self.repformers.dim_in: diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index 4e6ffb7785..e89e7467d3 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -8,7 +8,12 @@ def _make_env_mat( - nlist, coord, rcut: float, ruct_smth: float, radial_only: bool = False + nlist, + coord, + rcut: float, + ruct_smth: float, + radial_only: bool = False, + protection: float = 0.0, ): """Make smooth environment matrix.""" bsz, natoms, nnei = nlist.shape @@ -25,8 +30,8 @@ def _make_env_mat( length = torch.linalg.norm(diff, dim=-1, keepdim=True) # for index 0 nloc atom length = length + ~mask.unsqueeze(-1) - t0 = 1 / length - t1 = diff / length**2 + t0 = 1 / (length + protection) + t1 = diff / (length + protection) ** 2 weight = compute_smooth_weight(length, ruct_smth, rcut) weight = weight * mask.unsqueeze(-1) if radial_only: @@ -45,6 +50,7 @@ def prod_env_mat( rcut: float, rcut_smth: float, radial_only: bool = False, + protection: float = 0.0, ): """Generate smooth environment matrix from atom coordinates and other context. @@ -56,13 +62,19 @@ def prod_env_mat( - rcut: Cut-off radius. - rcut_smth: Smooth hyper-parameter for pair force & energy. - radial_only: Whether to return a full description or a radial-only descriptor. + - protection: Protection parameter to prevent division by zero errors during calculations. Returns ------- - env_mat: Shape is [nframes, natoms[1]*nnei*4]. """ _env_mat_se_a, diff, switch = _make_env_mat( - nlist, extended_coord, rcut, rcut_smth, radial_only + nlist, + extended_coord, + rcut, + rcut_smth, + radial_only, + protection=protection, ) # shape [n_atom, dim, 4 or 1] t_avg = mean[atype] # [n_atom, dim, 4 or 1] t_std = stddev[atype] # [n_atom, dim, 4 or 1] diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 3e8bf72f77..a908d2e057 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -4,6 +4,7 @@ Dict, List, Optional, + Tuple, Union, ) @@ -24,6 +25,9 @@ from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) from deepmd.pt.utils.utils import ( get_activation_fn, ) @@ -83,6 +87,8 @@ def __init__( set_davg_zero: bool = True, # TODO smooth: bool = True, add_type_ebd_to_seq: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, type: Optional[str] = None, ): """ @@ -114,6 +120,9 @@ def __init__( self.act = get_activation_fn(activation_function) self.direct_dist = direct_dist self.add_type_ebd_to_seq = add_type_ebd_to_seq + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection self.g2_embd = mylinear(1, self.g2_dim) layers = [] @@ -211,6 +220,13 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def forward( self, nlist: torch.Tensor, @@ -233,6 +249,7 @@ def forward( self.stddev, self.rcut, self.rcut_smth, + protection=self.env_protection, ) nlist_mask = nlist != -1 sw = torch.squeeze(sw, -1) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index c4b2c772f8..e17b7c5d54 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -79,6 +79,7 @@ def __init__( precision: str = "float64", resnet_dt: bool = False, exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, old_impl: bool = False, type_one_side: bool = True, **kwargs, @@ -95,6 +96,7 @@ def __init__( precision=precision, resnet_dt=resnet_dt, exclude_types=exclude_types, + env_protection=env_protection, old_impl=old_impl, type_one_side=type_one_side, **kwargs, @@ -249,6 +251,7 @@ def serialize(self) -> dict: "embeddings": obj.filter_layers.serialize(), "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, "@variables": { "davg": obj["davg"].detach().cpu().numpy(), "dstd": obj["dstd"].detach().cpu().numpy(), @@ -310,6 +313,7 @@ def __init__( precision: str = "float64", resnet_dt: bool = False, exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, old_impl: bool = False, type_one_side: bool = True, trainable: bool = True, @@ -336,6 +340,7 @@ def __init__( self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt self.old_impl = old_impl + self.env_protection = env_protection self.ntypes = len(sel) self.type_one_side = type_one_side # order matters, placed after the assignment of self.ntypes @@ -539,6 +544,7 @@ def forward( self.stddev, self.rcut, self.rcut_smth, + protection=self.env_protection, ) if self.old_impl: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index db9202c7fc..051c66385c 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -4,6 +4,7 @@ Dict, List, Optional, + Tuple, Union, ) @@ -26,6 +27,9 @@ from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) from deepmd.utils.env_mat_stat import ( StatItem, ) @@ -61,6 +65,8 @@ def __init__( normalize=True, temperature=None, return_rot=False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, type: Optional[str] = None, ): """Construct an embedding net of type `se_atten`. @@ -96,6 +102,7 @@ def __init__( self.normalize = normalize self.temperature = temperature self.return_rot = return_rot + self.env_protection = env_protection if isinstance(sel, int): sel = [sel] @@ -106,6 +113,8 @@ def __init__( self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) self.dpa1_attention = NeighborWiseAttention( self.attn_layer, self.nnei, @@ -249,6 +258,13 @@ def get_stats(self) -> Dict[str, StatItem]: ) return self.stats + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def forward( self, nlist: torch.Tensor, @@ -284,6 +300,7 @@ def forward( self.stddev, self.rcut, self.rcut_smth, + protection=self.env_protection, ) # [nfxnlocxnnei, self.ndescrpt] dmatrix = dmatrix.view(-1, self.ndescrpt) diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 5a4920b0e6..ff922e0649 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -64,6 +64,7 @@ def __init__( precision: str = "float64", resnet_dt: bool = False, exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, old_impl: bool = False, trainable: bool = True, **kwargs, @@ -81,7 +82,9 @@ def __init__( self.old_impl = False # this does not support old implementation. self.exclude_types = exclude_types self.ntypes = len(sel) - self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection self.sel = sel self.sec = torch.tensor( @@ -253,6 +256,13 @@ def __getitem__(self, key): else: raise KeyError(key) + def reinit_exclude( + self, + exclude_types: List[Tuple[int, int]] = [], + ): + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def forward( self, coord_ext: torch.Tensor, @@ -302,6 +312,7 @@ def forward( self.rcut, self.rcut_smth, True, + protection=self.env_protection, ) assert self.filter_layers is not None @@ -362,6 +373,7 @@ def serialize(self) -> dict: "embeddings": self.filter_layers.serialize(), "env_mat": DPEnvMat(self.rcut, self.rcut_smth).serialize(), "exclude_types": self.exclude_types, + "env_protection": self.env_protection, "@variables": { "davg": self["davg"].detach().cpu().numpy(), "dstd": self["dstd"].detach().cpu().numpy(), diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index cd53f0a6b3..8e4352e60c 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -22,6 +22,9 @@ from deepmd.pt.model.task import ( BaseFitting, ) +from deepmd.utils.spin import ( + Spin, +) from .dp_model import ( DPModel, @@ -41,6 +44,40 @@ from .model import ( BaseModel, ) +from .spin_model import ( + SpinEnergyModel, + SpinModel, +) + + +def get_spin_model(model_params): + model_params = copy.deepcopy(model_params) + # include virtual spin and placeholder types + model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] + spin = Spin( + use_spin=model_params["spin"]["use_spin"], + virtual_scale=model_params["spin"]["virtual_scale"], + ) + pair_exclude_types = spin.get_pair_exclude_types( + exclude_types=model_params.get("pair_exclude_types", None) + ) + model_params["pair_exclude_types"] = pair_exclude_types + # for descriptor data stat + model_params["descriptor"]["exclude_types"] = pair_exclude_types + atom_exclude_types = spin.get_atom_exclude_types( + exclude_types=model_params.get("atom_exclude_types", None) + ) + model_params["atom_exclude_types"] = atom_exclude_types + if ( + "env_protection" not in model_params["descriptor"] + or model_params["descriptor"]["env_protection"] == 0.0 + ): + model_params["descriptor"]["env_protection"] = 1e-6 + if model_params["descriptor"]["type"] in ["se_e2_a"]: + # only expand sel for se_e2_a + model_params["descriptor"]["sel"] += model_params["descriptor"]["sel"] + backbone_model = get_standard_model(model_params) + return SpinEnergyModel(backbone_model=backbone_model, spin=spin) def get_zbl_model(model_params): @@ -87,7 +124,7 @@ def get_zbl_model(model_params): ) -def get_model(model_params): +def get_standard_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) # descriptor @@ -120,12 +157,22 @@ def get_model(model_params): return model +def get_model(model_params): + if "spin" in model_params: + return get_spin_model(model_params) + elif "use_srtab" in model_params: + return get_zbl_model(model_params) + else: + return get_standard_model(model_params) + + __all__ = [ "BaseModel", "get_model", - "get_zbl_model", "DPModel", "EnergyModel", + "SpinModel", + "SpinEnergyModel", "DPZBLModel", "make_model", "make_hessian_model", diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index cacf59c16c..fdf9334119 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -92,15 +92,16 @@ def forward_lower( model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] if self.do_grad_r("energy"): - model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] - model_predict = model_ret return model_predict @classmethod diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py new file mode 100644 index 0000000000..df2f48e2e4 --- /dev/null +++ b/deepmd/pt/model/model/spin_model.py @@ -0,0 +1,560 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import functools +from typing import ( + Dict, + List, + Optional, +) + +import torch + +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.spin import ( + Spin, +) + +from .dp_model import ( + DPModel, +) + + +class SpinModel(torch.nn.Module): + """A spin model wrapper, with spin input preprocess and output split.""" + + def __init__( + self, + backbone_model, + spin: Spin, + ): + super().__init__() + self.backbone_model = backbone_model + self.spin = spin + self.ntypes_real = self.spin.ntypes_real + self.virtual_scale_mask = to_torch_tensor(self.spin.get_virtual_scale_mask()) + self.spin_mask = to_torch_tensor(self.spin.get_spin_mask()) + + def process_spin_input(self, coord, atype, spin): + """Generate virtual coordinates and types, concat into the input.""" + nframes, nloc = coord.shape[:-1] + atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) + virtual_coord = coord + spin * self.virtual_scale_mask[atype].reshape( + [nframes, nloc, 1] + ) + coord_spin = torch.concat([coord, virtual_coord], dim=-2) + return coord_spin, atype_spin + + def process_spin_input_lower( + self, + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping: Optional[torch.Tensor] = None, + ): + """ + Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`. + Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. + """ + nframes, nall = extended_coord.shape[:2] + nloc = nlist.shape[1] + virtual_extended_coord = ( + extended_coord + + extended_spin + * self.virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + ) + virtual_extended_atype = extended_atype + self.ntypes_real + extended_coord_updated = self.concat_switch_virtual( + extended_coord, virtual_extended_coord, nloc + ) + extended_atype_updated = self.concat_switch_virtual( + extended_atype, virtual_extended_atype, nloc + ) + if mapping is not None: + virtual_mapping = mapping + nloc + mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc) + else: + mapping_updated = None + # extend the nlist + nlist_updated = self.extend_nlist(extended_atype, nlist) + return ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) + + def process_spin_output( + self, atype, out_tensor, add_mag: bool = True, virtual_scale: bool = True + ): + """ + Split the output both real and virtual atoms, and scale the latter. + add_mag: whether to add magnetic tensor onto the real tensor. + Default: True. e.g. Ture for forces and False for atomic virials on real atoms. + virtual_scale: whether to scale the magnetic tensor with virtual scale factor. + Default: True. e.g. Ture for forces and False for atomic virials on virtual atoms. + """ + nframes, nloc_double = out_tensor.shape[:2] + nloc = nloc_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[atype].reshape([nframes, nloc, 1]) + out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1) + if add_mag: + out_real = out_real + out_mag + out_mag = (out_mag.view([nframes, nloc, -1]) * atomic_mask).view(out_mag.shape) + return out_real, out_mag, atomic_mask > 0.0 + + def process_spin_output_lower( + self, + extended_atype, + extended_out_tensor, + nloc: int, + add_mag: bool = True, + virtual_scale: bool = True, + ): + """ + Split the extended output of both real and virtual atoms with switch, and scale the latter. + add_mag: whether to add magnetic tensor onto the real tensor. + Default: True. e.g. Ture for forces and False for atomic virials on real atoms. + virtual_scale: whether to scale the magnetic tensor with virtual scale factor. + Default: True. e.g. Ture for forces and False for atomic virials on virtual atoms. + """ + nframes, nall_double = extended_out_tensor.shape[:2] + nall = nall_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + extended_out_real = torch.cat( + [ + extended_out_tensor[:, :nloc], + extended_out_tensor[:, nloc + nloc : nloc + nall], + ], + dim=1, + ) + extended_out_mag = torch.cat( + [ + extended_out_tensor[:, nloc : nloc + nloc], + extended_out_tensor[:, nloc + nall :], + ], + dim=1, + ) + if add_mag: + extended_out_real = extended_out_real + extended_out_mag + extended_out_mag = ( + extended_out_mag.view([nframes, nall, -1]) * atomic_mask + ).view(extended_out_mag.shape) + return extended_out_real, extended_out_mag, atomic_mask > 0.0 + + @staticmethod + def extend_nlist(extended_atype, nlist): + nframes, nloc, nnei = nlist.shape + nall = extended_atype.shape[1] + nlist_mask = nlist != -1 + nlist[nlist == -1] = 0 + nlist_shift = nlist + nall + nlist[~nlist_mask] = -1 + nlist_shift[~nlist_mask] = -1 + self_spin = torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device) + nall + self_spin = self_spin.view(1, -1, 1).expand(nframes, -1, -1) + # self spin + real neighbor + virtual neighbor + # nf x nloc x (1 + nnei + nnei) + extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1) + # nf x (nloc + nloc) x (1 + nnei + nnei) + extended_nlist = torch.cat( + [extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2 + ) + # update the index for switch + first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall) + second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc)) + extended_nlist[first_part_index] += nloc + extended_nlist[second_part_index] -= nall - nloc + return extended_nlist + + @staticmethod + def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int): + """ + Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms. + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. + """ + nframes, nall = extended_tensor.shape[:2] + out_shape = list(extended_tensor.shape) + out_shape[1] *= 2 + extended_tensor_updated = torch.zeros( + out_shape, + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] + extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[ + :, :nloc + ] + extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[ + :, nloc: + ] + extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] + return extended_tensor_updated.view(out_shape) + + @staticmethod + def expand_aparam(aparam, nloc: int): + """Expand the atom parameters for virtual atoms if necessary.""" + nframes, natom, numb_aparam = aparam.shape[1:] + if natom == nloc: # good + pass + elif natom < nloc: # for spin with virtual atoms + aparam = torch.concat( + [ + aparam, + torch.zeros( + [nframes, nloc - natom, numb_aparam], + device=aparam.device, + dtype=aparam.dtype, + ), + ], + dim=1, + ) + else: + raise ValueError( + f"get an input aparam with {aparam.shape[1]} inputs, ", + f"which is larger than {nloc} atoms.", + ) + return aparam + + @torch.jit.export + def get_type_map(self) -> List[str]: + """Get the type map.""" + tmap = self.backbone_model.get_type_map() + ntypes = len(tmap) // 2 # ignore the virtual type + return tmap[:ntypes] + + @torch.jit.export + def get_rcut(self): + """Get the cut-off radius.""" + return self.backbone_model.get_rcut() + + @torch.jit.export + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.backbone_model.get_dim_fparam() + + @torch.jit.export + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.backbone_model.get_dim_aparam() + + @torch.jit.export + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model. + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.backbone_model.get_sel_type() + + @torch.jit.export + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + If False, the shape is (nframes, nloc, ndim). + """ + return self.backbone_model.is_aparam_nall() + + @torch.jit.export + def model_output_type(self) -> List[str]: + """Get the output type for the model.""" + return self.backbone_model.model_output_type() + + @torch.jit.export + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.backbone_model.get_model_def_script() + + @torch.jit.export + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + # for C++ interface + if not self.backbone_model.mixed_types(): + return self.backbone_model.get_nnei() // 2 # ignore the virtual selected + else: + return self.backbone_model.get_nnei() + + @torch.jit.export + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + if not self.backbone_model.mixed_types(): + return self.backbone_model.get_nsel() // 2 # ignore the virtual selected + else: + return self.backbone_model.get_nsel() + + @torch.jit.export + def has_spin(self) -> bool: + """Returns whether it has spin input and output.""" + return True + + def __getattr__(self, name): + """Get attribute from the wrapped model.""" + if ( + name == "backbone_model" + ): # torch.nn.Module will exclude modules to self.__dict__["_modules"] + return self.__dict__["_modules"]["backbone_model"] + elif name in self.__dict__: + return self.__dict__[name] + else: + return getattr(self.backbone_model, name) + + def compute_or_load_stat( + self, + sampled_func, + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. + + Parameters + ---------- + sampled_func + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. + """ + + @functools.lru_cache + def spin_sampled_func(): + sampled = sampled_func() + spin_sampled = [] + for sys in sampled: + coord_updated, atype_updated = self.process_spin_input( + sys["coord"], sys["atype"], sys["spin"] + ) + tmp_dict = { + "coord": coord_updated, + "atype": atype_updated, + } + if "natoms" in sys: + natoms = sys["natoms"] + tmp_dict["natoms"] = torch.cat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 + ) + for item_key in sys.keys(): + if item_key not in ["coord", "atype", "spin", "natoms"]: + tmp_dict[item_key] = sys[item_key] + spin_sampled.append(tmp_dict) + return spin_sampled + + self.backbone_model.compute_or_load_stat(spin_sampled_func, stat_file_path) + + def forward_common( + self, + coord, + atype, + spin, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + nframes, nloc = coord.shape[:2] + coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) + model_ret = self.backbone_model.forward_common( + coord_updated, + atype_updated, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_ret[f"{var_name}"] = torch.split( + model_ret[f"{var_name}"], [nloc, nloc], dim=1 + )[0] + if self.backbone_model.do_grad_r(var_name): + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output(atype, model_ret[f"{var_name}_derv_r"]) + if self.backbone_model.do_grad_c(var_name) and do_atomic_virial: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output( + atype, + model_ret[f"{var_name}_derv_c"], + add_mag=False, + virtual_scale=False, + ) + return model_ret + + def forward_common_lower( + self, + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + nframes, nloc = nlist.shape[:2] + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) = self.process_spin_input_lower( + extended_coord, extended_atype, extended_spin, nlist, mapping=mapping + ) + model_ret = self.backbone_model.forward_common_lower( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping=mapping_updated, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_ret[f"{var_name}"] = torch.split( + model_ret[f"{var_name}"], [nloc, nloc], dim=1 + )[0] + if self.backbone_model.do_grad_r(var_name): + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, model_ret[f"{var_name}_derv_r"], nloc + ) + if self.backbone_model.do_grad_c(var_name) and do_atomic_virial: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, + model_ret[f"{var_name}_derv_c"], + nloc, + add_mag=False, + virtual_scale=False, + ) + return model_ret + + def serialize(self) -> dict: + return { + "backbone_model": self.backbone_model.serialize(), + "spin": self.spin.serialize(), + } + + @classmethod + def deserialize(cls, data) -> "SpinModel": + backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + spin = Spin.deserialize(data["spin"]) + return cls( + backbone_model=backbone_model_obj, + spin=spin, + ) + + +class SpinEnergyModel(SpinModel): + """A spin model for energy.""" + + model_type = "ener" + + def __init__( + self, + backbone_model, + spin: Spin, + ): + super().__init__(backbone_model, spin) + + def forward( + self, + coord, + atype, + spin, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + if aparam is not None: + aparam = self.expand_aparam(aparam, coord.shape[1]) + model_ret = self.forward_common( + coord, + atype, + spin, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) + # not support virial by far + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force_mag"] = model_ret[ + "energy_derv_r_mag" + ].squeeze(-2) + # not support virial by far + return model_predict diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index a11f6410a4..b593ddc3cc 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -162,7 +162,7 @@ def compute_output_stats( """ bias_atom_e = compute_output_stats( - merged, stat_file_path, self.rcond, self.atom_ener + merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener ) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 22fb409cad..09f8563bfb 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -508,10 +508,10 @@ def _forward_common( assert self.aparam_inv_std is not None if aparam.shape[-1] != self.numb_aparam: raise ValueError( - "get an input aparam of dim {aparam.shape[-1]}, ", - "which is not consistent with {self.numb_aparam}.", + f"get an input aparam of dim {aparam.shape[-1]}, ", + f"which is not consistent with {self.numb_aparam}.", ) - aparam = aparam.view([nf, nloc, self.numb_aparam]) + aparam = aparam.view([nf, -1, self.numb_aparam]) nb, nloc, _ = aparam.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 2a80956b9d..6938db9b3c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -25,6 +25,7 @@ ) from deepmd.pt.loss import ( DenoiseLoss, + EnergySpinLoss, EnergyStdLoss, TensorLoss, ) @@ -207,27 +208,31 @@ def single_model_stat( _stat_file_path, _data_requirement, ): - _training_data.add_data_requirement(_data_requirement) - if _validation_data is not None: - _validation_data.add_data_requirement(_data_requirement) if _model.get_dim_fparam() > 0: fparam_requirement_items = [ DataRequirementItem( "fparam", _model.get_dim_fparam(), atomic=False, must=True ) ] - _training_data.add_data_requirement(fparam_requirement_items) - if _validation_data is not None: - _validation_data.add_data_requirement(fparam_requirement_items) + _data_requirement += fparam_requirement_items if _model.get_dim_aparam() > 0: aparam_requirement_items = [ DataRequirementItem( "aparam", _model.get_dim_aparam(), atomic=True, must=True ) ] - _training_data.add_data_requirement(aparam_requirement_items) - if _validation_data is not None: - _validation_data.add_data_requirement(aparam_requirement_items) + _data_requirement += aparam_requirement_items + has_spin = getattr(_model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin: + spin_requirement_items = [ + DataRequirementItem("spin", ndof=3, atomic=True, must=True) + ] + _data_requirement += spin_requirement_items + _training_data.add_data_requirement(_data_requirement) + if _validation_data is not None: + _validation_data.add_data_requirement(_data_requirement) if not resuming and self.rank == 0: @functools.lru_cache @@ -268,6 +273,9 @@ def get_loss(loss_params, start_lr, _ntypes, _model): if loss_type == "ener": loss_params["starter_learning_rate"] = start_lr return EnergyStdLoss(**loss_params) + elif loss_type == "ener_spin": + loss_params["starter_learning_rate"] = start_lr + return EnergySpinLoss(**loss_params) elif loss_type == "denoise": loss_params["ntypes"] = _ntypes return DenoiseLoss(**loss_params) @@ -961,7 +969,7 @@ def get_data(self, is_train=True, task_key="Default"): batch_data = next(iter(self.validation_data[task_key])) for key in batch_data.keys(): - if key == "sid" or key == "fid": + if key == "sid" or key == "fid" or key == "box": continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: @@ -973,8 +981,8 @@ def get_data(self, is_train=True, task_key="Default"): input_keys = [ "coord", "atype", - "box", "spin", + "box", "fparam", "aparam", ] diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index a455041526..c1040fb9e3 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -141,8 +141,8 @@ def forward( self, coord, atype, - box: Optional[torch.Tensor] = None, spin: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, cur_lr: Optional[torch.Tensor] = None, label: Optional[torch.Tensor] = None, task_key: Optional[torch.Tensor] = None, @@ -157,14 +157,20 @@ def forward( assert ( task_key is not None ), f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}." - model_pred = self.model[task_key]( - coord, - atype, - box=box, - do_atomic_virial=do_atomic_virial, - fparam=fparam, - aparam=aparam, - ) + input_dict = { + "coord": coord, + "atype": atype, + "box": box, + "do_atomic_virial": do_atomic_virial, + "fparam": fparam, + "aparam": aparam, + } + has_spin = getattr(self.model[task_key], "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin: + input_dict["spin"] = spin + model_pred = self.model[task_key](**input_dict) natoms = atype.shape[-1] if not self.inference_only and not inference_only: loss, more_loss = self.loss[task_key]( diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index cd2943e6a8..47e17e9eaa 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -4,6 +4,8 @@ Dict, Iterator, List, + Tuple, + Union, ) import numpy as np @@ -18,6 +20,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) @@ -73,13 +78,13 @@ def __init__(self, descriptor: "DescriptorBlock"): ) # se_r=1, se_a=4 def iter( - self, data: List[Dict[str, torch.Tensor]] + self, data: List[Dict[str, Union[torch.Tensor, List[Tuple[int, int]]]]] ) -> Iterator[Dict[str, StatItem]]: """Get the iterator of the environment matrix. Parameters ---------- - data : List[Dict[str, torch.Tensor]] + data : List[Dict[str, Union[torch.Tensor, List[Tuple[int, int]]]]] The data. Yields @@ -139,6 +144,7 @@ def iter( # TODO: export rcut_smth from DescriptorBlock self.descriptor.rcut_smth, radial_only, + protection=self.descriptor.env_protection, ) # reshape to nframes * nloc at the atom level, # so nframes/mixed_type do not matter @@ -156,9 +162,16 @@ def iter( self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 ).view(-1, 1), ) + if "pair_exclude_types" in system: + # shape: (1, nloc, nnei) + exclude_mask = PairExcludeMask( + self.descriptor.get_ntypes(), system["pair_exclude_types"] + )(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1) + # shape: (ntypes, nloc, nnei) + type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask) for type_i in range(self.descriptor.get_ntypes()): dd = env_mat[type_idx[type_i]] - dd = dd.reshape([-1, self.last_dim]) # typen_atoms * nnei, 4 + dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4 env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :1] if self.last_dim == 4: diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py index 6df8df8dd0..9ddae3a416 100644 --- a/deepmd/pt/utils/exclude_mask.py +++ b/deepmd/pt/utils/exclude_mask.py @@ -37,6 +37,12 @@ def reinit( ) self.type_mask = to_torch_tensor(self.type_mask).view([-1]) + def get_exclude_types(self): + return self.exclude_types + + def get_type_mask(self): + return self.type_mask + def forward( self, atype: torch.Tensor, @@ -46,7 +52,7 @@ def forward( Parameters ---------- atype - The extended aotm types. shape: nf x natom + The extended atom types. shape: nf x natom Returns ------- @@ -97,6 +103,9 @@ def reinit( self.type_mask = to_torch_tensor(self.type_mask).view([-1]) self.no_exclusion = len(self._exclude_types) == 0 + def get_exclude_types(self): + return self._exclude_types + # may have a better place for this method... def forward( self, diff --git a/deepmd/pt/utils/multi_task.py b/deepmd/pt/utils/multi_task.py index ae3933a101..5f06d93208 100644 --- a/deepmd/pt/utils/multi_task.py +++ b/deepmd/pt/utils/multi_task.py @@ -143,8 +143,12 @@ def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None): ) for shared_key in shared_links: shared_links[shared_key]["links"] = sorted( - shared_links[shared_key]["links"], key=lambda x: x["shared_level"] + shared_links[shared_key]["links"], + key=lambda x: x["shared_level"] + - ("spin" in model_config["model_dict"][x["model_key"]]) * 100, ) + # little trick to make spin models in the front to be the base models, + # because its type embeddings are more general. assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" return model_config, shared_links diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index cfc75d9438..d37931b65a 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -27,14 +27,16 @@ def extend_input_and_build_neighbor_list( ): nframes, nloc = atype.shape[:2] if box is not None: + box_gpu = box.to(coord.device, non_blocking=True) coord_normalized = normalize_coord( coord.view(nframes, nloc, 3), - box.reshape(nframes, 3, 3), + box_gpu.reshape(nframes, 3, 3), ) else: + box_gpu = None coord_normalized = coord.clone() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, rcut + coord_normalized, atype, box_gpu, rcut, box ) nlist = build_neighbor_list( extended_coord, @@ -105,6 +107,8 @@ def build_neighbor_list( assert list(diff.shape) == [batch_size, nloc, nall, 3] # nloc x nall rr = torch.linalg.norm(diff, dim=-1) + # if central atom has two zero distances, sorting sometimes can not exclude itself + rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0) rr, nlist = torch.sort(rr, dim=-1) # nloc x (nall-1) rr = rr[:, :, 1:] @@ -262,6 +266,7 @@ def extend_coord_with_ghosts( atype: torch.Tensor, cell: Optional[torch.Tensor], rcut: float, + cell_cpu: Optional[torch.Tensor] = None, ): """Extend the coordinates of the atoms by appending peridoc images. The number of images is large enough to ensure all the neighbors @@ -277,6 +282,8 @@ def extend_coord_with_ghosts( simulation cell tensor of shape [-1, 9]. rcut : float the cutoff radius + cell_cpu : torch.Tensor + cell on cpu for performance Returns ------- @@ -299,27 +306,25 @@ def extend_coord_with_ghosts( else: coord = coord.view([nf, nloc, 3]) cell = cell.view([nf, 3, 3]) + cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell # nf x 3 - to_face = to_face_distance(cell) + to_face = to_face_distance(cell_cpu) # nf x 3 # *2: ghost copies on + and - directions # +1: central cell nbuff = torch.ceil(rcut / to_face).to(torch.long) # 3 nbuff = torch.max(nbuff, dim=0, keepdim=False).values - xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device) - yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device) - zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device) - xyz = xi.view(-1, 1, 1, 1) * torch.tensor( - [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( - [0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( - [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) + nbuff_cpu = nbuff.cpu() + xi = torch.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu") + yi = torch.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, device="cpu") + zi = torch.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1, device="cpu") + eye_3 = torch.eye(3, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device="cpu") + xyz = xi.view(-1, 1, 1, 1) * eye_3[0] + xyz = xyz + yi.view(1, -1, 1, 1) * eye_3[1] + xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2] xyz = xyz.view(-1, 3) + xyz = xyz.to(device=device, non_blocking=True) # ns x 3 shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))] ns, _ = shift_idx.shape @@ -332,7 +337,6 @@ def extend_coord_with_ghosts( extend_atype = torch.tile(atype.unsqueeze(-2), [1, ns, 1]) # nf x ns x nloc extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1]) - return ( extend_coord.reshape([nf, nall * 3]).to(device), extend_atype.view([nf, nall]).to(device), diff --git a/deepmd/pt/utils/region.py b/deepmd/pt/utils/region.py index b07d2f73bf..9d811acb9b 100644 --- a/deepmd/pt/utils/region.py +++ b/deepmd/pt/utils/region.py @@ -21,7 +21,7 @@ def phys2inter( the internal coordinates """ - rec_cell = torch.linalg.inv(cell) + rec_cell, _ = torch.linalg.inv_ex(cell) return torch.matmul(coord, rec_cell) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 63abccc75d..5e631d9412 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -11,6 +11,7 @@ import torch from deepmd.pt.utils import ( + AtomExcludeMask, env, ) from deepmd.pt.utils.utils import ( @@ -71,6 +72,7 @@ def make_stat_input(datasets, dataloaders, nbatches): def compute_output_stats( merged: Union[Callable[[], List[dict]], List[dict]], + ntypes: int, stat_file_path: Optional[DPPath] = None, rcond: Optional[float] = None, atom_ener: Optional[List[float]] = None, @@ -87,6 +89,8 @@ def compute_output_stats( - Callable[[], List[dict]]: A lazy function that returns data samples in the above format only when needed. Since the sampling process can be slow and memory-intensive, the lazy function helps by only sampling once. + ntypes : int + The number of atom types. stat_file_path : DPPath, optional The path to the stat file. rcond : float, optional @@ -107,10 +111,14 @@ def compute_output_stats( sampled = merged energy = [item["energy"] for item in sampled] data_mixed_type = "real_natoms_vec" in sampled[0] - if data_mixed_type: - input_natoms = [item["real_natoms_vec"] for item in sampled] - else: - input_natoms = [item["natoms"] for item in sampled] + natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" + for system in sampled: + if "atom_exclude_types" in system: + type_mask = AtomExcludeMask( + ntypes, system["atom_exclude_types"] + ).get_type_mask() + system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) + input_natoms = [item[natoms_key] for item in sampled] # shape: (nframes, ndim) merged_energy = to_numpy_array(torch.cat(energy)) # shape: (nframes, ntypes) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 0e15ba13a8..4635554610 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -154,6 +154,8 @@ class DescrptSeA(DescrptSe): Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed multi_task If the model has multi fitting nets to train. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. References ---------- @@ -182,6 +184,7 @@ def __init__( multi_task: bool = False, spin: Optional[Spin] = None, stripped_type_embedding: bool = False, + env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: """Constructor.""" @@ -189,6 +192,8 @@ def __init__( raise RuntimeError( f"rcut_smth ({rcut_smth:f}) should be no more than rcut ({rcut:f})!" ) + if env_protection != 0.0: + raise NotImplementedError("env_protection != 0.0 is not supported.") self.sel_a = sel self.rcut_r = rcut self.rcut_r_smth = rcut_smth @@ -206,6 +211,7 @@ def __init__( self.filter_np_precision = get_np_precision(precision) self.orig_exclude_types = exclude_types self.exclude_types = set() + self.env_protection = env_protection for tt in exclude_types: assert len(tt) == 2 self.exclude_types.add((tt[0], tt[1])) @@ -1436,6 +1442,7 @@ def serialize(self, suffix: str = "") -> dict: "trainable": self.trainable, "type_one_side": self.type_one_side, "exclude_types": list(self.orig_exclude_types), + "env_protection": self.env_protection, "set_davg_zero": self.set_davg_zero, "activation_function": self.activation_function_name, "precision": self.filter_precision.name, diff --git a/deepmd/tf/descriptor/se_r.py b/deepmd/tf/descriptor/se_r.py index ba1a261390..9f88ebe37d 100644 --- a/deepmd/tf/descriptor/se_r.py +++ b/deepmd/tf/descriptor/se_r.py @@ -104,6 +104,7 @@ def __init__( uniform_seed: bool = False, multi_task: bool = False, spin: Optional[Spin] = None, + env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: """Constructor.""" @@ -111,6 +112,8 @@ def __init__( raise RuntimeError( f"rcut_smth ({rcut_smth:f}) should be no more than rcut ({rcut:f})!" ) + if env_protection != 0.0: + raise NotImplementedError("env_protection != 0.0 is not supported.") self.sel_r = sel self.rcut = rcut self.rcut_smth = rcut_smth @@ -125,6 +128,7 @@ def __init__( self.filter_precision = get_precision(precision) self.orig_exclude_types = exclude_types self.exclude_types = set() + self.env_protection = env_protection for tt in exclude_types: assert len(tt) == 2 self.exclude_types.add((tt[0], tt[1])) @@ -776,6 +780,7 @@ def serialize(self, suffix: str = "") -> dict: "trainable": self.trainable, "type_one_side": self.type_one_side, "exclude_types": list(self.orig_exclude_types), + "env_protection": self.env_protection, "set_davg_zero": self.set_davg_zero, "activation_function": self.activation_function_name, "precision": self.filter_precision.name, diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index 45eda3392f..b9db0863b5 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -4,6 +4,7 @@ ) from typing import ( TYPE_CHECKING, + Any, Callable, Dict, List, @@ -693,6 +694,7 @@ def eval( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, + **kwargs: Dict[str, Any], ) -> Dict[str, np.ndarray]: """Evaluate the energy, force and virial by using this DP. @@ -724,6 +726,8 @@ def eval( efield The external field on atoms. The array should be of size nframes x natoms x 3 + **kwargs + Other parameters Returns ------- diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8bc9104b16..5e8db431f8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -96,11 +96,35 @@ def spin_args(): doc_use_spin = "Whether to use atomic spin model for each atom type" doc_spin_norm = "The magnitude of atomic spin for each atom type with spin" doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin" + doc_virtual_scale = ( + "The scaling factor to determine the virtual distance between a virtual atom " + "representing spin and its corresponding real atom for each atom type with spin. " + "This factor is defined as the virtual distance divided by the magnitude of atomic spin " + "for each atom type with spin. The virtual coordinate is defined as the real coordinate " + "plus spin * virtual_scale. List of float values with shape of [ntypes] or [ntypes_spin] " + "or one single float value for all types, only used when use_spin is True for each atom type." + ) return [ Argument("use_spin", List[bool], doc=doc_use_spin), - Argument("spin_norm", List[float], doc=doc_spin_norm), - Argument("virtual_len", List[float], doc=doc_virtual_len), + Argument( + "spin_norm", + List[float], + optional=True, + doc=doc_only_tf_supported + doc_spin_norm, + ), + Argument( + "virtual_len", + List[float], + optional=True, + doc=doc_only_tf_supported + doc_virtual_len, + ), + Argument( + "virtual_scale", + List[float], + optional=True, + doc=doc_only_pt_supported + doc_virtual_scale, + ), ] @@ -203,6 +227,7 @@ def descrpt_se_a_args(): doc_trainable = "If the parameters in the embedding net is trainable" doc_seed = "Random seed for parameter initialization" doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used" return [ @@ -241,6 +266,13 @@ def descrpt_se_a_args(): default=[], doc=doc_exclude_types, ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_tf_supported + doc_env_protection, + ), Argument( "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero ), diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 194c6b1e24..1e1d7c2251 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -491,7 +491,7 @@ def reformat_data_torch(self, data): if "find_" in kk: pass else: - if self.data_dict[kk]["atomic"]: + if kk in data and self.data_dict[kk]["atomic"]: data[kk] = data[kk].reshape(-1, self.data_dict[kk]["ndof"]) data["atype"] = data["type"] if not self.pbc: diff --git a/deepmd/utils/spin.py b/deepmd/utils/spin.py new file mode 100644 index 0000000000..38e8da48da --- /dev/null +++ b/deepmd/utils/spin.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + List, + Tuple, + Union, +) + +import numpy as np + + +class Spin: + """Class for spin, mainly processes the spin type-related information. + Atom types can be split into three kinds: + 1. Real types: real atom species, "Fe", "H", "O", etc. + 2. Spin types: atom species with spin, as virtual atoms in input, "Fe_spin", etc. + 3. Placeholder types: atom species without spin, as placeholders in input without contribution, + also name "H_spin", "O_spin", etc. + For any types in 2. or 3., the type index is `ntypes` plus index of its corresponding real type. + + Parameters + ---------- + use_spin: List[bool] + A list of boolean values indicating whether to use atomic spin for each atom type. + True for spin and False for not. List of bool values with shape of [ntypes]. + virtual_scale: List[float], float + The scaling factor to determine the virtual distance + between a virtual atom representing spin and its corresponding real atom + for each atom type with spin. This factor is defined as the virtual distance + divided by the magnitude of atomic spin for each atom type with spin. + The virtual coordinate is defined as the real coordinate plus spin * virtual_scale. + List of float values with shape of [ntypes] or [ntypes_spin] or one single float value for all types, + only used when use_spin is True for each atom type. + """ + + def __init__( + self, + use_spin: List[bool], + virtual_scale: Union[List[float], float], + ) -> None: + self.ntypes_real = len(use_spin) + self.ntypes_spin = use_spin.count(True) + self.use_spin = np.array(use_spin) + self.spin_mask = self.use_spin.astype(np.int64) + self.ntypes_real_and_spin = self.ntypes_real + self.ntypes_spin + self.ntypes_placeholder = self.ntypes_real - self.ntypes_spin + self.ntypes_input = 2 * self.ntypes_real # with placeholder for input types + self.real_type = np.arange(self.ntypes_real) + self.spin_type = np.arange(self.ntypes_real)[self.use_spin] + self.ntypes_real + self.real_and_spin_type = np.concatenate([self.real_type, self.spin_type]) + self.placeholder_type = ( + np.arange(self.ntypes_real)[~self.use_spin] + self.ntypes_real + ) + self.spin_placeholder_type = np.arange(self.ntypes_real) + self.ntypes_real + self.input_type = np.arange(self.ntypes_real * 2) + if isinstance(virtual_scale, list): + if len(virtual_scale) == self.ntypes_real: + self.virtual_scale = virtual_scale + elif len(virtual_scale) == self.ntypes_spin: + self.virtual_scale = np.zeros(self.ntypes_real) + self.virtual_scale[self.use_spin] = virtual_scale + else: + raise ValueError( + f"Invalid length of virtual_scale for spin atoms" + f": Expected {self.ntypes_real} or { self.ntypes_spin} but got {len(virtual_scale)}!" + ) + elif isinstance(virtual_scale, float): + self.virtual_scale = [virtual_scale for _ in range(self.ntypes_real)] + else: + raise ValueError(f"Invalid virtual scale type: {type(virtual_scale)}") + self.virtual_scale = np.array(self.virtual_scale) + self.virtual_scale_mask = (self.virtual_scale * self.use_spin).reshape([-1]) + self.pair_exclude_types = [] + self.init_pair_exclude_types_placeholder() + self.atom_exclude_types_ps = [] + self.init_atom_exclude_types_placeholder_spin() + self.atom_exclude_types_p = [] + self.init_atom_exclude_types_placeholder() + + def get_ntypes_real(self) -> int: + """Returns the number of real atom types.""" + return self.ntypes_real + + def get_ntypes_spin(self) -> int: + """Returns the number of atom types which contain spin.""" + return self.ntypes_spin + + def get_ntypes_real_and_spin(self) -> int: + """Returns the number of real atom types and types which contain spin.""" + return self.ntypes_real_and_spin + + def get_ntypes_input(self) -> int: + """Returns the number of double real atom types for input placeholder.""" + return self.ntypes_input + + def get_use_spin(self) -> List[bool]: + """Returns the list of whether to use spin for each atom type.""" + return self.use_spin + + def get_virtual_scale(self) -> np.ndarray: + """Returns the list of magnitude of atomic spin for each atom type.""" + return self.virtual_scale + + def init_pair_exclude_types_placeholder(self) -> None: + """ + Initialize the pair-wise exclusion types for descriptor. + The placeholder types for those without spin are excluded. + """ + ti_grid, tj_grid = np.meshgrid( + self.placeholder_type, self.input_type, indexing="ij" + ) + self.pair_exclude_types = ( + np.stack((ti_grid, tj_grid), axis=-1).reshape(-1, 2).tolist() + ) + + def init_atom_exclude_types_placeholder_spin(self) -> None: + """ + Initialize the atom-wise exclusion types for fitting. + Both the placeholder types and spin types are excluded. + """ + self.atom_exclude_types_ps = self.spin_placeholder_type.tolist() + + def init_atom_exclude_types_placeholder(self) -> None: + """ + Initialize the atom-wise exclusion types for fitting. + The placeholder types for those without spin are excluded. + """ + self.atom_exclude_types_p = self.placeholder_type.tolist() + + def get_pair_exclude_types(self, exclude_types=None) -> List[Tuple[int, int]]: + """ + Return the pair-wise exclusion types for descriptor. + The placeholder types for those without spin are excluded. + """ + if exclude_types is None: + return self.pair_exclude_types + else: + _exclude_types: List[Tuple[int, int]] = copy.deepcopy( + self.pair_exclude_types + ) + for tt in exclude_types: + assert len(tt) == 2 + _exclude_types.append((tt[0], tt[1])) + return _exclude_types + + def get_atom_exclude_types(self, exclude_types=None) -> List[int]: + """ + Return the atom-wise exclusion types for fitting before out_def. + Both the placeholder types and spin types are excluded. + """ + if exclude_types is None: + return self.atom_exclude_types_ps + else: + _exclude_types: List[int] = copy.deepcopy(self.atom_exclude_types_ps) + _exclude_types += exclude_types + _exclude_types = list(set(_exclude_types)) + return _exclude_types + + def get_atom_exclude_types_placeholder(self, exclude_types=None) -> List[int]: + """ + Return the atom-wise exclusion types for fitting after out_def. + The placeholder types for those without spin are excluded. + """ + if exclude_types is None: + return self.atom_exclude_types_p + else: + _exclude_types: List[int] = copy.deepcopy(self.atom_exclude_types_p) + _exclude_types += exclude_types + _exclude_types = list(set(_exclude_types)) + return _exclude_types + + def get_spin_mask(self): + """ + Return the spin mask of shape [ntypes], + with spin types being 1, and non-spin types being 0. + """ + return self.spin_mask + + def get_virtual_scale_mask(self): + """ + Return the virtual scale mask of shape [ntypes], + with spin types being its virtual scale, and non-spin types being 0. + """ + return self.virtual_scale_mask + + def serialize( + self, + ) -> dict: + return { + "use_spin": self.use_spin.tolist(), + "virtual_scale": self.virtual_scale.tolist(), + } + + @classmethod + def deserialize( + cls, + data: dict, + ) -> "Spin": + return cls(**data) diff --git a/examples/spin/data_reformat/data_0/set.000/box.npy b/examples/spin/data_reformat/data_0/set.000/box.npy new file mode 100644 index 0000000000..1f72eb7185 Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/box.npy differ diff --git a/examples/spin/data_reformat/data_0/set.000/coord.npy b/examples/spin/data_reformat/data_0/set.000/coord.npy new file mode 100644 index 0000000000..4b60ae0e0b Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/coord.npy differ diff --git a/examples/spin/data_reformat/data_0/set.000/energy.npy b/examples/spin/data_reformat/data_0/set.000/energy.npy new file mode 100644 index 0000000000..8754b6dad2 Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/energy.npy differ diff --git a/examples/spin/data_reformat/data_0/set.000/force.npy b/examples/spin/data_reformat/data_0/set.000/force.npy new file mode 100644 index 0000000000..e95173d561 Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/force.npy differ diff --git a/examples/spin/data_reformat/data_0/set.000/force_mag.npy b/examples/spin/data_reformat/data_0/set.000/force_mag.npy new file mode 100644 index 0000000000..65bc1ef837 Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/force_mag.npy differ diff --git a/examples/spin/data_reformat/data_0/set.000/spin.npy b/examples/spin/data_reformat/data_0/set.000/spin.npy new file mode 100644 index 0000000000..c426f1c7f6 Binary files /dev/null and b/examples/spin/data_reformat/data_0/set.000/spin.npy differ diff --git a/examples/spin/data_reformat/data_0/type.raw b/examples/spin/data_reformat/data_0/type.raw new file mode 100644 index 0000000000..d9664c7a22 --- /dev/null +++ b/examples/spin/data_reformat/data_0/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/examples/spin/data_reformat/data_0/type_map.raw b/examples/spin/data_reformat/data_0/type_map.raw new file mode 100644 index 0000000000..7eca995c31 --- /dev/null +++ b/examples/spin/data_reformat/data_0/type_map.raw @@ -0,0 +1,2 @@ +Ni +O diff --git a/examples/spin/data_reformat/data_1/set.000/box.npy b/examples/spin/data_reformat/data_1/set.000/box.npy new file mode 100644 index 0000000000..1f72eb7185 Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/box.npy differ diff --git a/examples/spin/data_reformat/data_1/set.000/coord.npy b/examples/spin/data_reformat/data_1/set.000/coord.npy new file mode 100644 index 0000000000..fc51107998 Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/coord.npy differ diff --git a/examples/spin/data_reformat/data_1/set.000/energy.npy b/examples/spin/data_reformat/data_1/set.000/energy.npy new file mode 100644 index 0000000000..a0eecad8d8 Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/energy.npy differ diff --git a/examples/spin/data_reformat/data_1/set.000/force.npy b/examples/spin/data_reformat/data_1/set.000/force.npy new file mode 100644 index 0000000000..ec4a05f8f2 Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/force.npy differ diff --git a/examples/spin/data_reformat/data_1/set.000/force_mag.npy b/examples/spin/data_reformat/data_1/set.000/force_mag.npy new file mode 100644 index 0000000000..844df39b76 Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/force_mag.npy differ diff --git a/examples/spin/data_reformat/data_1/set.000/spin.npy b/examples/spin/data_reformat/data_1/set.000/spin.npy new file mode 100644 index 0000000000..1444e35c5f Binary files /dev/null and b/examples/spin/data_reformat/data_1/set.000/spin.npy differ diff --git a/examples/spin/data_reformat/data_1/type.raw b/examples/spin/data_reformat/data_1/type.raw new file mode 100644 index 0000000000..d9664c7a22 --- /dev/null +++ b/examples/spin/data_reformat/data_1/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/examples/spin/data_reformat/data_1/type_map.raw b/examples/spin/data_reformat/data_1/type_map.raw new file mode 100644 index 0000000000..7eca995c31 --- /dev/null +++ b/examples/spin/data_reformat/data_1/type_map.raw @@ -0,0 +1,2 @@ +Ni +O diff --git a/examples/spin/data_reformat/data_2/set.000/box.npy b/examples/spin/data_reformat/data_2/set.000/box.npy new file mode 100644 index 0000000000..4e817ccff5 Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/box.npy differ diff --git a/examples/spin/data_reformat/data_2/set.000/coord.npy b/examples/spin/data_reformat/data_2/set.000/coord.npy new file mode 100644 index 0000000000..aa515d0b6e Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/coord.npy differ diff --git a/examples/spin/data_reformat/data_2/set.000/energy.npy b/examples/spin/data_reformat/data_2/set.000/energy.npy new file mode 100644 index 0000000000..cd4efe3b55 Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/energy.npy differ diff --git a/examples/spin/data_reformat/data_2/set.000/force.npy b/examples/spin/data_reformat/data_2/set.000/force.npy new file mode 100644 index 0000000000..5cf07333e0 Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/force.npy differ diff --git a/examples/spin/data_reformat/data_2/set.000/force_mag.npy b/examples/spin/data_reformat/data_2/set.000/force_mag.npy new file mode 100644 index 0000000000..14b73ffb54 Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/force_mag.npy differ diff --git a/examples/spin/data_reformat/data_2/set.000/spin.npy b/examples/spin/data_reformat/data_2/set.000/spin.npy new file mode 100644 index 0000000000..4bd1396c7d Binary files /dev/null and b/examples/spin/data_reformat/data_2/set.000/spin.npy differ diff --git a/examples/spin/data_reformat/data_2/type.raw b/examples/spin/data_reformat/data_2/type.raw new file mode 100644 index 0000000000..d9664c7a22 --- /dev/null +++ b/examples/spin/data_reformat/data_2/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/examples/spin/data_reformat/data_2/type_map.raw b/examples/spin/data_reformat/data_2/type_map.raw new file mode 100644 index 0000000000..7eca995c31 --- /dev/null +++ b/examples/spin/data_reformat/data_2/type_map.raw @@ -0,0 +1,2 @@ +Ni +O diff --git a/examples/spin/se_e2_a/input.json b/examples/spin/se_e2_a/input_tf.json similarity index 100% rename from examples/spin/se_e2_a/input.json rename to examples/spin/se_e2_a/input_tf.json diff --git a/examples/spin/se_e2_a/input_torch.json b/examples/spin/se_e2_a/input_torch.json new file mode 100644 index 0000000000..37859b8402 --- /dev/null +++ b/examples/spin/se_e2_a/input_torch.json @@ -0,0 +1,90 @@ +{ + "model": { + "type_map": [ + "Ni", + "O" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 60, + 60 + ], + "rcut_smth": 5.4, + "rcut": 5.6, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "data_stat_nbatch": 10, + "spin": { + "use_spin": [ + true, + false + ], + "virtual_scale": [ + 0.3140 + ], + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment": "that's all" + }, + "loss": { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data_reformat/data_0", + "../data_reformat/data_1" + ], + "batch_size": 3, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data_reformat/data_2" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "numb_steps": 100000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 10000, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/source/tests/common/dpmodel/test_output_def.py b/source/tests/common/dpmodel/test_output_def.py index 272082513c..27fa54ea4c 100644 --- a/source/tests/common/dpmodel/test_output_def.py +++ b/source/tests/common/dpmodel/test_output_def.py @@ -55,6 +55,15 @@ def test_model_output_def(self): atomic=True, r_hessian=True, ), + OutputVariableDef( + "energy3", + [1], + reduciable=True, + r_differentiable=True, + c_differentiable=True, + atomic=True, + magnetic=True, + ), OutputVariableDef( "dos", [10], @@ -74,7 +83,7 @@ def test_model_output_def(self): ] # fitting definition fd = FittingOutputDef(defs) - expected_keys = ["energy", "energy2", "dos", "foo"] + expected_keys = ["energy", "energy2", "energy3", "dos", "foo"] self.assertEqual( set(expected_keys), set(fd.keys()), @@ -82,16 +91,19 @@ def test_model_output_def(self): # shape self.assertEqual(fd["energy"].shape, [1]) self.assertEqual(fd["energy2"].shape, [1]) + self.assertEqual(fd["energy3"].shape, [1]) self.assertEqual(fd["dos"].shape, [10]) self.assertEqual(fd["foo"].shape, [3]) # atomic self.assertEqual(fd["energy"].atomic, True) self.assertEqual(fd["energy2"].atomic, True) + self.assertEqual(fd["energy3"].atomic, True) self.assertEqual(fd["dos"].atomic, True) self.assertEqual(fd["foo"].atomic, True) # reduce self.assertEqual(fd["energy"].reduciable, True) self.assertEqual(fd["energy2"].reduciable, True) + self.assertEqual(fd["energy3"].reduciable, True) self.assertEqual(fd["dos"].reduciable, True) self.assertEqual(fd["foo"].reduciable, False) # derivative @@ -101,15 +113,25 @@ def test_model_output_def(self): self.assertEqual(fd["energy2"].r_differentiable, True) self.assertEqual(fd["energy2"].c_differentiable, True) self.assertEqual(fd["energy2"].r_hessian, True) + self.assertEqual(fd["energy3"].r_differentiable, True) + self.assertEqual(fd["energy3"].c_differentiable, True) + self.assertEqual(fd["energy3"].r_hessian, False) self.assertEqual(fd["dos"].r_differentiable, False) self.assertEqual(fd["foo"].r_differentiable, False) self.assertEqual(fd["dos"].c_differentiable, False) self.assertEqual(fd["foo"].c_differentiable, False) + # magnetic + self.assertEqual(fd["energy"].magnetic, False) + self.assertEqual(fd["energy2"].magnetic, False) + self.assertEqual(fd["energy3"].magnetic, True) + self.assertEqual(fd["dos"].magnetic, False) + self.assertEqual(fd["foo"].magnetic, False) # model definition md = ModelOutputDef(fd) expected_keys = [ "energy", "energy2", + "energy3", "dos", "foo", "energy_redu", @@ -121,7 +143,15 @@ def test_model_output_def(self): "energy2_derv_r_derv_r", "energy2_derv_c", "energy2_derv_c_redu", + "energy3_redu", + "energy3_derv_r", + "energy3_derv_c", + "energy3_derv_c_redu", + "energy3_derv_r_mag", + "energy3_derv_c_mag", "dos_redu", + "mask", + "mask_mag", ] self.assertEqual( set(expected_keys), @@ -132,6 +162,7 @@ def test_model_output_def(self): # reduce self.assertEqual(md["energy"].reduciable, True) self.assertEqual(md["energy2"].reduciable, True) + self.assertEqual(md["energy3"].reduciable, True) self.assertEqual(md["dos"].reduciable, True) self.assertEqual(md["foo"].reduciable, False) # derivative @@ -141,13 +172,19 @@ def test_model_output_def(self): self.assertEqual(md["energy2"].r_differentiable, True) self.assertEqual(md["energy2"].c_differentiable, True) self.assertEqual(md["energy2"].r_hessian, True) + self.assertEqual(md["energy3"].r_differentiable, True) + self.assertEqual(md["energy3"].c_differentiable, True) + self.assertEqual(md["energy3"].r_hessian, False) self.assertEqual(md["dos"].r_differentiable, False) self.assertEqual(md["foo"].r_differentiable, False) self.assertEqual(md["dos"].c_differentiable, False) self.assertEqual(md["foo"].c_differentiable, False) # shape + self.assertEqual(md["mask"].shape, [1]) + self.assertEqual(md["mask_mag"].shape, [1]) self.assertEqual(md["energy"].shape, [1]) self.assertEqual(md["energy2"].shape, [1]) + self.assertEqual(md["energy3"].shape, [1]) self.assertEqual(md["dos"].shape, [10]) self.assertEqual(md["foo"].shape, [3]) self.assertEqual(md["energy_redu"].shape, [1]) @@ -159,6 +196,11 @@ def test_model_output_def(self): self.assertEqual(md["energy2_derv_c"].shape, [1, 9]) self.assertEqual(md["energy2_derv_c_redu"].shape, [1, 9]) self.assertEqual(md["energy2_derv_r_derv_r"].shape, [1, 3, 3]) + self.assertEqual(md["energy3_derv_r"].shape, [1, 3]) + self.assertEqual(md["energy3_derv_c"].shape, [1, 9]) + self.assertEqual(md["energy3_derv_c_redu"].shape, [1, 9]) + self.assertEqual(md["energy3_derv_r_mag"].shape, [1, 3]) + self.assertEqual(md["energy3_derv_c_mag"].shape, [1, 9]) # atomic self.assertEqual(md["energy"].atomic, True) self.assertEqual(md["energy2"].atomic, True) @@ -173,9 +215,18 @@ def test_model_output_def(self): self.assertEqual(md["energy2_derv_c"].atomic, True) self.assertEqual(md["energy2_derv_c_redu"].atomic, False) self.assertEqual(md["energy2_derv_r_derv_r"].atomic, True) + self.assertEqual(md["energy3_redu"].atomic, False) + self.assertEqual(md["energy3_derv_r"].atomic, True) + self.assertEqual(md["energy3_derv_c"].atomic, True) + self.assertEqual(md["energy3_derv_r_mag"].atomic, True) + self.assertEqual(md["energy3_derv_c_mag"].atomic, True) + self.assertEqual(md["energy3_derv_c_redu"].atomic, False) # category + self.assertEqual(md["mask"].category, OutputVariableCategory.OUT) + self.assertEqual(md["mask_mag"].category, OutputVariableCategory.OUT) self.assertEqual(md["energy"].category, OutputVariableCategory.OUT) self.assertEqual(md["energy2"].category, OutputVariableCategory.OUT) + self.assertEqual(md["energy3"].category, OutputVariableCategory.OUT) self.assertEqual(md["dos"].category, OutputVariableCategory.OUT) self.assertEqual(md["foo"].category, OutputVariableCategory.OUT) self.assertEqual(md["energy_redu"].category, OutputVariableCategory.REDU) @@ -193,6 +244,18 @@ def test_model_output_def(self): self.assertEqual( md["energy2_derv_r_derv_r"].category, OutputVariableCategory.DERV_R_DERV_R ) + self.assertEqual(md["energy3_redu"].category, OutputVariableCategory.REDU) + self.assertEqual(md["energy3_derv_r"].category, OutputVariableCategory.DERV_R) + self.assertEqual(md["energy3_derv_c"].category, OutputVariableCategory.DERV_C) + self.assertEqual( + md["energy3_derv_c_redu"].category, OutputVariableCategory.DERV_C_REDU + ) + self.assertEqual( + md["energy3_derv_r_mag"].category, OutputVariableCategory.DERV_R + ) + self.assertEqual( + md["energy3_derv_c_mag"].category, OutputVariableCategory.DERV_C + ) # flag OVO = OutputVariableOperation self.assertEqual(md["energy"].category & OVO.REDU, 0) @@ -201,6 +264,9 @@ def test_model_output_def(self): self.assertEqual(md["energy2"].category & OVO.REDU, 0) self.assertEqual(md["energy2"].category & OVO.DERV_R, 0) self.assertEqual(md["energy2"].category & OVO.DERV_C, 0) + self.assertEqual(md["energy3"].category & OVO.REDU, 0) + self.assertEqual(md["energy3"].category & OVO.DERV_R, 0) + self.assertEqual(md["energy3"].category & OVO.DERV_C, 0) self.assertEqual(md["dos"].category & OVO.REDU, 0) self.assertEqual(md["dos"].category & OVO.DERV_R, 0) self.assertEqual(md["dos"].category & OVO.DERV_C, 0) @@ -261,6 +327,46 @@ def test_model_output_def(self): self.assertEqual(md[kk].category & OVO.DERV_R, OVO.DERV_R) self.assertEqual(md[kk].category & OVO.DERV_C, 0) self.assertEqual(md[kk].category & OVO._SEC_DERV_R, OVO._SEC_DERV_R) + # flag: energy3 + self.assertEqual( + md["energy3_redu"].category & OVO.REDU, + OVO.REDU, + ) + self.assertEqual(md["energy3_redu"].category & OVO.DERV_R, 0) + self.assertEqual(md["energy3_redu"].category & OVO.DERV_C, 0) + self.assertEqual(md["energy3_derv_r"].category & OVO.REDU, 0) + self.assertEqual( + md["energy3_derv_r"].category & OVO.DERV_R, + OVO.DERV_R, + ) + self.assertEqual(md["energy3_derv_r"].category & OVO.DERV_C, 0) + self.assertEqual(md["energy3_derv_c"].category & OVO.REDU, 0) + self.assertEqual(md["energy3_derv_c"].category & OVO.DERV_R, 0) + self.assertEqual( + md["energy3_derv_c"].category & OVO.DERV_C, + OVO.DERV_C, + ) + self.assertEqual( + md["energy3_derv_c_redu"].category & OVO.REDU, + OVO.REDU, + ) + self.assertEqual(md["energy3_derv_c_redu"].category & OVO.DERV_R, 0) + self.assertEqual( + md["energy3_derv_c_redu"].category & OVO.DERV_C, + OVO.DERV_C, + ) + self.assertEqual(md["energy3_derv_r_mag"].category & OVO.REDU, 0) + self.assertEqual( + md["energy3_derv_r_mag"].category & OVO.DERV_R, + OVO.DERV_R, + ) + self.assertEqual(md["energy3_derv_r_mag"].category & OVO.DERV_C, 0) + self.assertEqual(md["energy3_derv_c_mag"].category & OVO.REDU, 0) + self.assertEqual(md["energy3_derv_c_mag"].category & OVO.DERV_R, 0) + self.assertEqual( + md["energy3_derv_c_mag"].category & OVO.DERV_C, + OVO.DERV_C, + ) # apply_operation: energy self.assertEqual( apply_operation(md["energy"], OVO.REDU), @@ -299,6 +405,31 @@ def test_model_output_def(self): apply_operation(md["energy2_derv_r"], OVO.DERV_R), md["energy2_derv_r_derv_r"].category, ) + # apply_operation: energy3 + self.assertEqual( + apply_operation(md["energy3"], OVO.REDU), + md["energy3_redu"].category, + ) + self.assertEqual( + apply_operation(md["energy3"], OVO.DERV_R), + md["energy3_derv_r"].category, + ) + self.assertEqual( + apply_operation(md["energy3"], OVO.DERV_C), + md["energy3_derv_c"].category, + ) + self.assertEqual( + apply_operation(md["energy3_derv_c"], OVO.REDU), + md["energy3_derv_c_redu"].category, + ) + self.assertEqual( + apply_operation(md["energy3"], OVO.DERV_R), + md["energy3_derv_r_mag"].category, + ) + self.assertEqual( + apply_operation(md["energy3"], OVO.DERV_C), + md["energy3_derv_c_mag"].category, + ) # raise ValueError with self.assertRaises(ValueError): apply_operation(md["energy_redu"], OVO.REDU) @@ -315,6 +446,15 @@ def test_model_output_def(self): apply_operation(md["energy2_derv_c_redu"], OVO.REDU) with self.assertRaises(ValueError): apply_operation(md["energy2_derv_r_derv_r"], OVO.DERV_R) + # raise ValueError + with self.assertRaises(ValueError): + apply_operation(md["energy3_redu"], OVO.REDU) + with self.assertRaises(ValueError): + apply_operation(md["energy3_derv_c"], OVO.DERV_C) + with self.assertRaises(ValueError): + apply_operation(md["energy3_derv_c_redu"], OVO.REDU) + with self.assertRaises(ValueError): + apply_operation(md["energy3_derv_c_mag"], OVO.DERV_C) # hession hession_cat = apply_operation(md["energy_derv_r"], OVO.DERV_R) self.assertEqual(hession_cat & OVO.DERV_R, OVO.DERV_R) @@ -378,6 +518,20 @@ def test_hessian_not_r_differentiable(self): ), ) + def test_energy_magnetic(self): + with self.assertRaises(ValueError) as context: + ( + OutputVariableDef( + "energy", + [1], + reduciable=False, + atomic=False, + r_differentiable=True, + r_hessian=True, + magnetic=True, + ), + ) + def test_model_decorator(self): nf = 2 nloc = 3 diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 647bee2bbb..1ec4cef3a5 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -40,7 +40,8 @@ p_examples / "fparam" / "train" / "input_aparam.json", p_examples / "zinc_protein" / "zinc_se_a_mask.json", p_examples / "dos" / "train" / "input.json", - p_examples / "spin" / "se_e2_a" / "input.json", + p_examples / "spin" / "se_e2_a" / "input_tf.json", + p_examples / "spin" / "se_e2_a" / "input_torch.json", p_examples / "dprc" / "normal" / "input.json", p_examples / "dprc" / "pairwise" / "input.json", p_examples / "dprc" / "generalized_force" / "input.json", diff --git a/source/tests/common/test_spin.py b/source/tests/common/test_spin.py new file mode 100644 index 0000000000..c3bca50b09 --- /dev/null +++ b/source/tests/common/test_spin.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import unittest + +import numpy as np + +from deepmd.utils.spin import ( + Spin, +) + +CUR_DIR = os.path.dirname(__file__) + + +class SpinTest(unittest.TestCase): + def setUp(self): + type_map_1 = ["H", "O"] + self.use_spin_1 = [False, False] + self.virtual_scale_1 = [0.1, 0.1] + + type_map_2 = ["B", "Ni", "O"] + self.use_spin_2 = [False, True, False] + self.virtual_scale_2 = [0.1, 0.1, 0.1] + + type_map_3 = ["H", "O", "B", "Ni", "O"] + self.use_spin_3 = [False, False, False, True, False] + self.virtual_scale_3 = [0.1, 0.1, 0.1, 0.1, 0.1] + + self.virtual_scale_float = 0.1 + self.virtual_scale_nspin = [0.1] + + self.spin_1 = Spin(self.use_spin_1, self.virtual_scale_1) + self.spin_2 = Spin(self.use_spin_2, self.virtual_scale_2) + self.spin_3 = Spin(self.use_spin_3, self.virtual_scale_3) + self.spin_3_float = Spin(self.use_spin_3, self.virtual_scale_float) + self.spin_3_nspin = Spin(self.use_spin_3, self.virtual_scale_nspin) + + self.expect_virtual_scale_mask_1 = np.array([0.0, 0.0]) + self.expect_virtual_scale_mask_2 = np.array([0.0, 0.1, 0.0]) + self.expect_virtual_scale_mask_3 = np.array([0.0, 0.0, 0.0, 0.1, 0.0]) + + self.expect_pair_exclude_types_1 = [ + [2, 0], + [2, 1], + [2, 2], + [2, 3], + [3, 0], + [3, 1], + [3, 2], + [3, 3], + ] + self.expect_pair_exclude_types_2 = [ + [3, 0], + [3, 1], + [3, 2], + [3, 3], + [3, 4], + [3, 5], + [5, 0], + [5, 1], + [5, 2], + [5, 3], + [5, 4], + [5, 5], + ] + self.expect_pair_exclude_types_3 = [ + [5, 0], + [5, 1], + [5, 2], + [5, 3], + [5, 4], + [5, 5], + [5, 6], + [5, 7], + [5, 8], + [5, 9], + [6, 0], + [6, 1], + [6, 2], + [6, 3], + [6, 4], + [6, 5], + [6, 6], + [6, 7], + [6, 8], + [6, 9], + [7, 0], + [7, 1], + [7, 2], + [7, 3], + [7, 4], + [7, 5], + [7, 6], + [7, 7], + [7, 8], + [7, 9], + [9, 0], + [9, 1], + [9, 2], + [9, 3], + [9, 4], + [9, 5], + [9, 6], + [9, 7], + [9, 8], + [9, 9], + ] + + def test_ntypes(self): + self.assertEqual(self.spin_1.get_ntypes_real(), 2) + self.assertEqual(self.spin_1.get_ntypes_spin(), 0) + self.assertEqual(self.spin_1.get_ntypes_real_and_spin(), 2) + self.assertEqual(self.spin_1.get_ntypes_input(), 4) + + self.assertEqual(self.spin_2.get_ntypes_real(), 3) + self.assertEqual(self.spin_2.get_ntypes_spin(), 1) + self.assertEqual(self.spin_2.get_ntypes_real_and_spin(), 4) + self.assertEqual(self.spin_2.get_ntypes_input(), 6) + + self.assertEqual(self.spin_3.get_ntypes_real(), 5) + self.assertEqual(self.spin_3.get_ntypes_spin(), 1) + self.assertEqual(self.spin_3.get_ntypes_real_and_spin(), 6) + self.assertEqual(self.spin_3.get_ntypes_input(), 10) + + def test_use_spin(self): + np.testing.assert_allclose(self.spin_1.get_use_spin(), self.use_spin_1) + np.testing.assert_allclose(self.spin_2.get_use_spin(), self.use_spin_2) + np.testing.assert_allclose(self.spin_3.get_use_spin(), self.use_spin_3) + + def test_mask(self): + np.testing.assert_allclose( + self.spin_1.get_virtual_scale_mask(), self.expect_virtual_scale_mask_1 + ) + np.testing.assert_allclose( + self.spin_2.get_virtual_scale_mask(), self.expect_virtual_scale_mask_2 + ) + np.testing.assert_allclose( + self.spin_3.get_virtual_scale_mask(), self.expect_virtual_scale_mask_3 + ) + + def test_exclude_types(self): + self.assertEqual( + sorted(self.spin_1.get_pair_exclude_types()), + sorted(self.expect_pair_exclude_types_1), + ) + self.assertEqual( + sorted(self.spin_2.get_pair_exclude_types()), + sorted(self.expect_pair_exclude_types_2), + ) + self.assertEqual( + sorted(self.spin_3.get_pair_exclude_types()), + sorted(self.expect_pair_exclude_types_3), + ) + + def test_virtual_scale_consistence(self): + np.testing.assert_allclose( + self.spin_3.get_virtual_scale(), self.spin_3_float.get_virtual_scale() + ) + np.testing.assert_allclose( + self.spin_3.get_virtual_scale_mask(), self.spin_3_nspin.get_virtual_scale() + ) + np.testing.assert_allclose( + self.spin_3.get_virtual_scale_mask(), + self.spin_3_float.get_virtual_scale_mask(), + ) + np.testing.assert_allclose( + self.spin_3.get_virtual_scale_mask(), + self.spin_3_nspin.get_virtual_scale_mask(), + ) + self.assertEqual( + self.spin_3.get_pair_exclude_types(), + self.spin_3_float.get_pair_exclude_types(), + ) diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index b8f4205d09..1e3e5ae86d 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -40,6 +40,7 @@ (True, False), # type_one_side ([], [[0, 1]]), # excluded_types ("float32", "float64"), # precision + (0.0, 1e-8, 1e-2), # env_protection ) class TestSeA(CommonTest, DescriptorTest, unittest.TestCase): @property @@ -49,6 +50,7 @@ def data(self) -> dict: type_one_side, excluded_types, precision, + env_protection, ) = self.param return { "sel": [9, 10], @@ -59,6 +61,7 @@ def data(self) -> dict: "resnet_dt": resnet_dt, "type_one_side": type_one_side, "exclude_types": excluded_types, + "env_protection": env_protection, "precision": precision, "seed": 1145141919810, } @@ -70,6 +73,7 @@ def skip_pt(self) -> bool: type_one_side, excluded_types, precision, + env_protection, ) = self.param return CommonTest.skip_pt @@ -80,9 +84,21 @@ def skip_dp(self) -> bool: type_one_side, excluded_types, precision, + env_protection, ) = self.param return CommonTest.skip_dp + @property + def skip_tf(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return env_protection != 0.0 + tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT @@ -127,6 +143,7 @@ def setUp(self): type_one_side, excluded_types, precision, + env_protection, ) = self.param if not type_one_side: idx = np.argsort(self.atype) @@ -172,6 +189,7 @@ def rtol(self) -> float: type_one_side, excluded_types, precision, + env_protection, ) = self.param if precision == "float64": return 1e-10 @@ -188,6 +206,7 @@ def atol(self) -> float: type_one_side, excluded_types, precision, + env_protection, ) = self.param if precision == "float64": return 1e-10 diff --git a/source/tests/pt/NiO/data/data_0/set.000/box.npy b/source/tests/pt/NiO/data/data_0/set.000/box.npy new file mode 100644 index 0000000000..1f72eb7185 Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/box.npy differ diff --git a/source/tests/pt/NiO/data/data_0/set.000/coord.npy b/source/tests/pt/NiO/data/data_0/set.000/coord.npy new file mode 100644 index 0000000000..4b60ae0e0b Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/coord.npy differ diff --git a/source/tests/pt/NiO/data/data_0/set.000/energy.npy b/source/tests/pt/NiO/data/data_0/set.000/energy.npy new file mode 100644 index 0000000000..8754b6dad2 Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/energy.npy differ diff --git a/source/tests/pt/NiO/data/data_0/set.000/force.npy b/source/tests/pt/NiO/data/data_0/set.000/force.npy new file mode 100644 index 0000000000..e95173d561 Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/force.npy differ diff --git a/source/tests/pt/NiO/data/data_0/set.000/force_mag.npy b/source/tests/pt/NiO/data/data_0/set.000/force_mag.npy new file mode 100644 index 0000000000..65bc1ef837 Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/force_mag.npy differ diff --git a/source/tests/pt/NiO/data/data_0/set.000/spin.npy b/source/tests/pt/NiO/data/data_0/set.000/spin.npy new file mode 100644 index 0000000000..c426f1c7f6 Binary files /dev/null and b/source/tests/pt/NiO/data/data_0/set.000/spin.npy differ diff --git a/source/tests/pt/NiO/data/data_0/type.raw b/source/tests/pt/NiO/data/data_0/type.raw new file mode 100644 index 0000000000..d9664c7a22 --- /dev/null +++ b/source/tests/pt/NiO/data/data_0/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/NiO/data/data_0/type_map.raw b/source/tests/pt/NiO/data/data_0/type_map.raw new file mode 100644 index 0000000000..7eca995c31 --- /dev/null +++ b/source/tests/pt/NiO/data/data_0/type_map.raw @@ -0,0 +1,2 @@ +Ni +O diff --git a/source/tests/pt/NiO/data/single/set.000/box.npy b/source/tests/pt/NiO/data/single/set.000/box.npy new file mode 100644 index 0000000000..d3ac265aa8 Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/box.npy differ diff --git a/source/tests/pt/NiO/data/single/set.000/coord.npy b/source/tests/pt/NiO/data/single/set.000/coord.npy new file mode 100644 index 0000000000..4060f0fc53 Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/coord.npy differ diff --git a/source/tests/pt/NiO/data/single/set.000/energy.npy b/source/tests/pt/NiO/data/single/set.000/energy.npy new file mode 100644 index 0000000000..fd7d1420ee Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/energy.npy differ diff --git a/source/tests/pt/NiO/data/single/set.000/force.npy b/source/tests/pt/NiO/data/single/set.000/force.npy new file mode 100644 index 0000000000..c5c238d200 Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/force.npy differ diff --git a/source/tests/pt/NiO/data/single/set.000/force_mag.npy b/source/tests/pt/NiO/data/single/set.000/force_mag.npy new file mode 100644 index 0000000000..3f0323ad8e Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/force_mag.npy differ diff --git a/source/tests/pt/NiO/data/single/set.000/spin.npy b/source/tests/pt/NiO/data/single/set.000/spin.npy new file mode 100644 index 0000000000..88985f5d2c Binary files /dev/null and b/source/tests/pt/NiO/data/single/set.000/spin.npy differ diff --git a/source/tests/pt/NiO/data/single/type.raw b/source/tests/pt/NiO/data/single/type.raw new file mode 100644 index 0000000000..d9664c7a22 --- /dev/null +++ b/source/tests/pt/NiO/data/single/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/NiO/data/single/type_map.raw b/source/tests/pt/NiO/data/single/type_map.raw new file mode 100644 index 0000000000..7eca995c31 --- /dev/null +++ b/source/tests/pt/NiO/data/single/type_map.raw @@ -0,0 +1,2 @@ +Ni +O diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index c32f202625..91fc3cabf6 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -7,11 +7,13 @@ from deepmd.pt.model.model import ( get_model, - get_zbl_model, ) from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) dtype = torch.float64 @@ -21,6 +23,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -59,34 +62,64 @@ def test( cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") coord = torch.rand([natoms, 3], dtype=dtype, device="cpu") coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device="cpu") atype = torch.IntTensor([0, 0, 0, 1, 1]) # assumes input to be numpy tensor coord = coord.numpy() - - def np_infer( + spin = spin.numpy() + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] + + def np_infer_coord( coord, ): - e0, f0, v0 = eval_model( + result = eval_model( self.model, torch.tensor(coord, device=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, + spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) - ret = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } # detach - ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} return ret - def ff(_coord): - return np_infer(_coord)["energy"] + def np_infer_spin( + spin, + ): + result = eval_model( + self.model, + torch.tensor(coord, device=env.DEVICE).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), + ) + # detach + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} + return ret - fdf = -finite_difference(ff, coord, delta=delta).squeeze() - rff = np_infer(coord)["force"] - np.testing.assert_almost_equal(fdf, rff, decimal=places) + def ff_coord(_coord): + return np_infer_coord(_coord)["energy"] + + def ff_spin(_spin): + return np_infer_spin(_spin)["energy"] + + if not test_spin: + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + else: + # real force + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + # magnetic force + fdf = -finite_difference(ff_spin, spin, delta=delta).squeeze() + rff = np_infer_spin(spin)["force_mag"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) class VirialTest: @@ -104,11 +137,12 @@ def test( # assumes input to be numpy tensor coord = coord.numpy() cell = cell.numpy() + test_keys = ["energy", "force", "virial"] def np_infer( new_cell, ): - e0, f0, v0 = eval_model( + result = eval_model( self.model, torch.tensor( stretch_box(coord, cell, new_cell), device="cpu" @@ -116,13 +150,9 @@ def np_infer( torch.tensor(new_cell, device="cpu").unsqueeze(0), atype, ) - ret = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } # detach - ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} + # detach return ret def ff(bb): @@ -211,11 +241,19 @@ class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelZBLVirial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeAForce(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 697ebb6411..68b1ff65d5 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -95,7 +95,8 @@ def test_dp_test(self): ).reshape(1, -1, 3) atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) - e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True) + ret = dp.eval(coord, cell, atype, atomic=True) + e, f, v, ae, av = ret[0], ret[1], ret[2], ret[3], ret[4] self.assertEqual(e.shape, (1, 1)) self.assertEqual(f.shape, (1, 5, 3)) self.assertEqual(v.shape, (1, 9)) diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index a1895718dd..63a3534c74 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -56,13 +56,22 @@ def get_single_batch(dataset, index=None): np_batch = dataset[index] pt_batch = {} - for key in ["coord", "box", "force", "energy", "virial", "atype", "natoms"]: + for key in [ + "coord", + "box", + "force", + "force_mag", + "energy", + "virial", + "atype", + "natoms", + ]: if key in np_batch.keys(): np_batch[key] = np.expand_dims(np_batch[key], axis=0) pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE) - np_batch["coord"] = np_batch["coord"].reshape(1, -1) + if key in ["coord", "force", "force_mag"]: + np_batch[key] = np_batch[key].reshape(1, -1) np_batch["natoms"] = np_batch["natoms"][0] - np_batch["force"] = np_batch["force"].reshape(1, -1) return np_batch, pt_batch diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py new file mode 100644 index 0000000000..2bd5c22aaf --- /dev/null +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.model import SpinModel as DPSpinModel +from deepmd.pt.model.model import ( + SpinEnergyModel, + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from .test_permutation import ( + model_dpa1, + model_dpa2, + model_se_e2_a, + model_spin, +) + +dtype = torch.float64 + + +def reduce_tensor(extended_tensor, mapping, nloc: int): + nframes, nall = extended_tensor.shape[:2] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + # nf x nloc x (*ext_dims) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +class SpinTest: + def setUp(self): + self.prec = 1e-10 + natoms = 5 + self.ntypes = 3 # ["O", "H", "B"] for test + self.cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0) + self.coord = 3.0 * torch.rand( + [natoms, 3], dtype=dtype, device=env.DEVICE + ).unsqueeze(0) + self.spin = 0.5 * torch.rand( + [natoms, 3], dtype=dtype, device=env.DEVICE + ).unsqueeze(0) + self.atype = torch.tensor( + [0, 0, 0, 1, 1], dtype=torch.int64, device=env.DEVICE + ).unsqueeze(0) + + self.expected_mask = torch.tensor( + [ + [True], + [True], + [True], + [False], + [False], + ], + dtype=torch.bool, + device=env.DEVICE, + ).unsqueeze(0) + self.expected_atype_with_spin = torch.tensor( + [0, 0, 0, 1, 1, 3, 3, 3, 4, 4], dtype=torch.int64, device=env.DEVICE + ).unsqueeze(0) + self.expected_nloc_spin_index = ( + torch.arange(natoms, natoms * 2, dtype=torch.int64, device=env.DEVICE) + .unsqueeze(0) + .unsqueeze(-1) + ) + + def test_output_shape( + self, + ): + result = self.model( + self.coord, + self.atype, + self.spin, + self.cell, + ) + # check magnetic mask + torch.testing.assert_close(result["mask_mag"], self.expected_mask) + # check output shape to assure split + nframes, nloc = self.coord.shape[:2] + torch.testing.assert_close(result["energy"].shape, [nframes, 1]) + torch.testing.assert_close(result["atom_energy"].shape, [nframes, nloc, 1]) + torch.testing.assert_close(result["force"].shape, [nframes, nloc, 3]) + torch.testing.assert_close(result["force_mag"].shape, [nframes, nloc, 3]) + + def test_input_output_process(self): + nframes, nloc = self.coord.shape[:2] + self.real_ntypes = self.model.spin.get_ntypes_real() + # 1. test forward input process + coord_updated, atype_updated = self.model.process_spin_input( + self.coord, self.atype, self.spin + ) + # compare atypes of real and virtual atoms + torch.testing.assert_close(atype_updated, self.expected_atype_with_spin) + # compare coords of real and virtual atoms + torch.testing.assert_close(coord_updated.shape, [nframes, nloc * 2, 3]) + torch.testing.assert_close(coord_updated[:, :nloc], self.coord) + virtual_scale = torch.tensor( + self.model.spin.get_virtual_scale_mask()[self.atype.cpu()], + dtype=dtype, + device=env.DEVICE, + ) + virtual_coord = self.coord + self.spin * virtual_scale.unsqueeze(-1) + torch.testing.assert_close(coord_updated[:, nloc:], virtual_coord) + + # 2. test forward output process + model_ret = self.model.backbone_model.forward_common( + coord_updated, + atype_updated, + self.cell, + do_atomic_virial=True, + ) + if self.model.do_grad_r("energy"): + force_all = model_ret["energy_derv_r"].squeeze(-2) + force_real, force_mag, _ = self.model.process_spin_output( + self.atype, force_all + ) + torch.testing.assert_close( + force_real, force_all[:, :nloc] + force_all[:, nloc:] + ) + torch.testing.assert_close( + force_mag, force_all[:, nloc:] * virtual_scale.unsqueeze(-1) + ) + + # 3. test forward_lower input process + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=self.model.mixed_types(), + box=self.cell, + ) + nall = extended_coord.shape[1] + nnei = nlist.shape[-1] + extended_spin = torch.gather( + self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 + ) + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) = self.model.process_spin_input_lower( + extended_coord, extended_atype, extended_spin, nlist, mapping=mapping + ) + # compare atypes of real and virtual atoms + # Note that the real and virtual atoms corresponding to the local ones are switch to the first nloc * 2 atoms + torch.testing.assert_close(extended_atype_updated.shape, [nframes, nall * 2]) + torch.testing.assert_close( + extended_atype_updated[:, :nloc], extended_atype[:, :nloc] + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc : nloc + nloc], + extended_atype[:, :nloc] + self.real_ntypes, + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc + nloc : nloc + nall], + extended_atype[:, nloc:nall], + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc + nall :], + extended_atype[:, nloc:nall] + self.real_ntypes, + ) + virtual_scale = torch.tensor( + self.model.spin.get_virtual_scale_mask()[extended_atype.cpu()], + dtype=dtype, + device=env.DEVICE, + ) + # compare coords of real and virtual atoms + virtual_coord = extended_coord + extended_spin * virtual_scale.unsqueeze(-1) + torch.testing.assert_close(extended_coord_updated.shape, [nframes, nall * 2, 3]) + torch.testing.assert_close( + extended_coord_updated[:, :nloc], extended_coord[:, :nloc] + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc : nloc + nloc], virtual_coord[:, :nloc] + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc + nloc : nloc + nall], + extended_coord[:, nloc:nall], + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc + nall :], virtual_coord[:, nloc:nall] + ) + + # compare mapping + torch.testing.assert_close(mapping_updated.shape, [nframes, nall * 2]) + torch.testing.assert_close(mapping_updated[:, :nloc], mapping[:, :nloc]) + torch.testing.assert_close( + mapping_updated[:, nloc : nloc + nloc], mapping[:, :nloc] + nloc + ) + torch.testing.assert_close( + mapping_updated[:, nloc + nloc : nloc + nall], mapping[:, nloc:nall] + ) + torch.testing.assert_close( + mapping_updated[:, nloc + nall :], mapping[:, nloc:nall] + nloc + ) + + # compare nlist + torch.testing.assert_close( + nlist_updated.shape, [nframes, nloc * 2, nnei * 2 + 1] + ) + # self spin + torch.testing.assert_close( + nlist_updated[:, :nloc, :1], self.expected_nloc_spin_index + ) + # real and virtual neighbors + loc_atoms_mask = (nlist < nloc) & (nlist != -1) + ghost_atoms_mask = nlist >= nloc + real_neighbors = nlist.clone() + real_neighbors[ghost_atoms_mask] += nloc + torch.testing.assert_close( + nlist_updated[:, :nloc, 1 : 1 + nnei], real_neighbors + ) + virtual_neighbors = nlist.clone() + virtual_neighbors[loc_atoms_mask] += nloc + virtual_neighbors[ghost_atoms_mask] += nall + torch.testing.assert_close( + nlist_updated[:, :nloc, 1 + nnei :], virtual_neighbors + ) + + # 4. test forward_lower output process + model_ret = self.model.backbone_model.forward_common_lower( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping=mapping_updated, + do_atomic_virial=True, + ) + if self.model.do_grad_r("energy"): + force_all = model_ret["energy_derv_r"].squeeze(-2) + force_real, force_mag, _ = self.model.process_spin_output_lower( + extended_atype, force_all, nloc + ) + force_all_switched = torch.zeros_like(force_all) + force_all_switched[:, :nloc] = force_all[:, :nloc] + force_all_switched[:, nloc:nall] = force_all[:, nloc + nloc : nloc + nall] + force_all_switched[:, nall : nall + nloc] = force_all[:, nloc : nloc + nloc] + force_all_switched[:, nall + nloc :] = force_all[:, nloc + nall :] + torch.testing.assert_close( + force_real, force_all_switched[:, :nall] + force_all_switched[:, nall:] + ) + torch.testing.assert_close( + force_mag, force_all_switched[:, nall:] * virtual_scale.unsqueeze(-1) + ) + + def test_jit(self): + model = torch.jit.script(self.model) + self.assertEqual(model.get_rcut(), self.rcut) + self.assertEqual(model.get_nsel(), self.nsel) + self.assertEqual(model.get_type_map(), self.type_map) + + def test_self_consistency(self): + if hasattr(self, "serial_test") and not self.serial_test: + # not implement serialize and deserialize + return + model1 = SpinEnergyModel.deserialize(self.model.serialize()) + result = model1( + self.coord, + self.atype, + self.spin, + self.cell, + ) + expected_result = self.model( + self.coord, + self.atype, + self.spin, + self.cell, + ) + for key in result: + torch.testing.assert_close( + result[key], expected_result[key], rtol=self.prec, atol=self.prec + ) + model1 = torch.jit.script(model1) + + def test_dp_consistency(self): + if hasattr(self, "serial_test") and not self.serial_test: + # not implement serialize and deserialize + return + dp_model = DPSpinModel.deserialize(self.model.serialize()) + # test call + dp_ret = dp_model.call( + to_numpy_array(self.coord), + to_numpy_array(self.atype), + to_numpy_array(self.spin), + to_numpy_array(self.cell), + ) + result = self.model.forward_common( + self.coord, + self.atype, + self.spin, + self.cell, + ) + np.testing.assert_allclose( + to_numpy_array(result["energy"]), + dp_ret["energy"], + rtol=self.prec, + atol=self.prec, + ) + np.testing.assert_allclose( + to_numpy_array(result["energy_redu"]), + dp_ret["energy_redu"], + rtol=self.prec, + atol=self.prec, + ) + + # test call_lower + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=self.model.mixed_types(), + box=self.cell, + ) + extended_spin = torch.gather( + self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 + ) + dp_ret_lower = dp_model.call_lower( + to_numpy_array(extended_coord), + to_numpy_array(extended_atype), + to_numpy_array(extended_spin), + to_numpy_array(nlist), + to_numpy_array(mapping), + ) + result_lower = self.model.forward_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + ) + np.testing.assert_allclose( + to_numpy_array(result_lower["energy"]), + dp_ret_lower["energy"], + rtol=self.prec, + atol=self.prec, + ) + np.testing.assert_allclose( + to_numpy_array(result_lower["energy_redu"]), + dp_ret_lower["energy_redu"], + rtol=self.prec, + atol=self.prec, + ) + + +class TestEnergyModelSpinSeA(unittest.TestCase, SpinTest): + def setUp(self): + SpinTest.setUp(self) + model_params = copy.deepcopy(model_spin) + model_params["descriptor"] = copy.deepcopy(model_se_e2_a["descriptor"]) + self.rcut = model_params["descriptor"]["rcut"] + self.nsel = sum(model_params["descriptor"]["sel"]) + self.type_map = model_params["type_map"] + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinDPA1(unittest.TestCase, SpinTest): + def setUp(self): + SpinTest.setUp(self) + model_params = copy.deepcopy(model_spin) + model_params["descriptor"] = copy.deepcopy(model_dpa1["descriptor"]) + self.rcut = model_params["descriptor"]["rcut"] + self.nsel = model_params["descriptor"]["sel"] + self.type_map = model_params["type_map"] + # not implement serialize and deserialize + self.serial_test = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinDPA2(unittest.TestCase, SpinTest): + def setUp(self): + SpinTest.setUp(self) + model_params = copy.deepcopy(model_spin) + model_params["descriptor"] = copy.deepcopy(model_dpa2["descriptor"]) + self.rcut = model_params["descriptor"]["repinit_rcut"] + self.nsel = model_params["descriptor"]["repinit_nsel"] + self.type_map = model_params["type_map"] + # not implement serialize and deserialize + self.serial_test = False + self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py new file mode 100644 index 0000000000..32be3b62ad --- /dev/null +++ b/source/tests/pt/model/test_forward_lower.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from .test_permutation import ( # model_dpau, + model_dpa1, + model_dpa2, + model_se_e2_a, + model_spin, + model_zbl, +) + +dtype = torch.float64 + + +def reduce_tensor(extended_tensor, mapping, nloc: int): + nframes, nall = extended_tensor.shape[:2] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + # nf x nloc x (*ext_dims) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +class ForwardLowerTest: + def test( + self, + ): + prec = self.prec + natoms = 5 + cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) + coord = 3.0 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = 0.5 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int64, device=env.DEVICE) + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag"] + + result_forward = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord.unsqueeze(0), + atype.unsqueeze(0), + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=self.model.mixed_types(), + box=cell.unsqueeze(0), + ) + extended_spin = torch.gather( + spin.unsqueeze(0), index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 + ) + input_dict = { + "extended_coord": extended_coord, + "extended_atype": extended_atype, + "nlist": nlist, + "mapping": mapping, + "do_atomic_virial": False, + } + if test_spin: + input_dict["extended_spin"] = extended_spin + result_forward_lower = self.model.forward_lower(**input_dict) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close( + result_forward_lower[key], result_forward[key], rtol=prec, atol=prec + ) + elif key in ["force", "force_mag"]: + reduced_vv = reduce_tensor( + result_forward_lower[f"extended_{key}"], mapping, natoms + ) + torch.testing.assert_close( + reduced_vv, result_forward[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + result_forward_lower[key], + result_forward[key], + rtol=prec, + atol=prec, + ) + else: + raise RuntimeError(f"Unexpected test key {key}") + + +class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_se_e2_a) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_zbl) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest): + def setUp(self): + # still need to figure out why only 1e-5 rtol and atol + self.prec = 1e-5 + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py index eb8ff714e8..c8f4307d52 100644 --- a/source/tests/pt/model/test_null_input.py +++ b/source/tests/pt/model/test_null_input.py @@ -41,14 +41,9 @@ def test_nloc_1( cell = (cell + cell.T) + 100.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0], dtype=torch.int32, device=env.DEVICE) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype - ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } + test_keys = ["energy", "force", "virial"] + result = eval_model(self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype) + ret0 = {key: result[key].squeeze(0) for key in test_keys} prec = 1e-10 expect_e_shape = [1] expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE) @@ -70,14 +65,9 @@ def test_nloc_2_far( # 2 far-away atoms coord = torch.cat([coord, coord + 100.0], dim=0) atype = torch.tensor([0, 2], dtype=torch.int32, device=env.DEVICE) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype - ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } + test_keys = ["energy", "force", "virial"] + result = eval_model(self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype) + ret0 = {key: result[key].squeeze(0) for key in test_keys} prec = 1e-10 expect_e_shape = [1] expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE) diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index fa97281718..8ec5c375fd 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -9,7 +9,6 @@ ) from deepmd.pt.model.model import ( get_model, - get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -23,7 +22,7 @@ "type": "se_e2_a", "sel": [46, 92, 4], "rcut_smth": 0.50, - "rcut": 6.00, + "rcut": 4.00, "neuron": [25, 50, 100], "resnet_dt": False, "axis_neuron": 16, @@ -61,6 +60,31 @@ "data_stat_nbatch": 20, } +model_spin = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 20, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140], + "_comment": " that's all", + }, +} + model_dpa2 = { "type_map": ["O", "H", "B"], "descriptor": { @@ -205,34 +229,46 @@ def test( cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) idx_perm = [1, 0, 4, 3, 2] - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord[idx_perm].unsqueeze(0), cell.unsqueeze(0), atype[idx_perm] + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord[idx_perm].unsqueeze(0), + cell.unsqueeze(0), + atype[idx_perm], + spins=spin[idx_perm].unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-10 - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - ret0["force"][idx_perm], ret1["force"], rtol=prec, atol=prec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - ret0["virial"], ret1["virial"], rtol=prec, atol=prec - ) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_mag"]: + torch.testing.assert_close( + ret0[key][idx_perm], ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=prec, atol=prec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") class TestEnergyModelSeA(unittest.TestCase, PermutationTest): @@ -299,7 +335,15 @@ class TestEnergyModelZBL(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeA(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) # class TestEnergyFoo(unittest.TestCase): diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 19f671e619..a12bd063b4 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -9,7 +9,6 @@ ) from deepmd.pt.model.model import ( get_model, - get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +19,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -34,80 +34,102 @@ def test( natoms = 5 cell = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) coord = 2 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = 2 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) shift = torch.tensor([4, 4, 4], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) from scipy.stats import ( special_ortho_group, ) + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag"] rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) # rotate only coord and shift to the center of cell coord_rot = torch.matmul(coord, rmat) - e0, f0, v0 = eval_model( - self.model, (coord + shift).unsqueeze(0), cell.unsqueeze(0), atype + spin_rot = torch.matmul(spin, rmat) + result_0 = eval_model( + self.model, + (coord + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, (coord_rot + shift).unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + (coord_rot + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin_rot.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), - ret1["virial"].view([3, 3]), - rtol=prec, - atol=prec, - ) - + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_mag"]: + torch.testing.assert_close( + torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul( + rmat.T, torch.matmul(ret0[key].view([3, 3]), rmat) + ), + ret1[key].view([3, 3]), + rtol=prec, + atol=prec, + ) + else: + raise RuntimeError(f"Unexpected test key {key}") # rotate coord and cell torch.manual_seed(0) cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) coord_rot = torch.matmul(coord, rmat) + spin_rot = torch.matmul(spin, rmat) cell_rot = torch.matmul(cell, rmat) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype - ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord_rot.unsqueeze(0), cell_rot.unsqueeze(0), atype + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord_rot.unsqueeze(0), + cell_rot.unsqueeze(0), + atype, + spins=spin_rot.unsqueeze(0), ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), - ret1["virial"].view([3, 3]), - rtol=prec, - atol=prec, - ) + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_mag"]: + torch.testing.assert_close( + torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul( + rmat.T, torch.matmul(ret0[key].view([3, 3]), rmat) + ), + ret1[key].view([3, 3]), + rtol=prec, + atol=prec, + ) class TestEnergyModelSeA(unittest.TestCase, RotTest): @@ -174,7 +196,15 @@ class TestEnergyModelZBL(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeA(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index bc1d26bffa..86e9ed94d7 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -9,7 +9,6 @@ ) from deepmd.pt.model.model import ( get_model, - get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +19,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -59,7 +59,7 @@ def test( ) coord1 = torch.matmul(coord1, cell) coord = torch.concat([coord0, coord1], dim=0) - + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord0 = torch.clone(coord) coord1 = torch.clone(coord) coord1[1][0] += epsilon @@ -68,52 +68,63 @@ def test( coord3 = torch.clone(coord) coord3[1][0] += epsilon coord3[2][1] += epsilon - - e0, f0, v0 = eval_model( - self.model, coord0.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] + + result_0 = eval_model( + self.model, + coord0.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord1.unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord1.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - e2, f2, v2 = eval_model( - self.model, coord2.unsqueeze(0), cell.unsqueeze(0), atype + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + result_2 = eval_model( + self.model, + coord2.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret2 = { - "energy": e2.squeeze(0), - "force": f2.squeeze(0), - "virial": v2.squeeze(0), - } - e3, f3, v3 = eval_model( - self.model, coord3.unsqueeze(0), cell.unsqueeze(0), atype + ret2 = {key: result_2[key].squeeze(0) for key in test_keys} + result_3 = eval_model( + self.model, + coord3.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret3 = { - "energy": e3.squeeze(0), - "force": f3.squeeze(0), - "virial": v3.squeeze(0), - } + ret3 = {key: result_3[key].squeeze(0) for key in test_keys} def compare(ret0, ret1): - torch.testing.assert_close( - ret0["energy"], ret1["energy"], rtol=rprec, atol=aprec - ) - # plus 1. to avoid the divided-by-zero issue - torch.testing.assert_close( - 1.0 + ret0["force"], 1.0 + ret1["force"], rtol=rprec, atol=aprec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - 1.0 + ret0["virial"], 1.0 + ret1["virial"], rtol=rprec, atol=aprec - ) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=rprec, atol=aprec + ) + elif key in ["force", "force_mag"]: + # plus 1. to avoid the divided-by-zero issue + torch.testing.assert_close( + 1.0 + ret0[key], 1.0 + ret1[key], rtol=rprec, atol=aprec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + 1.0 + ret0[key], 1.0 + ret1[key], rtol=rprec, atol=aprec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") compare(ret0, ret1) compare(ret1, ret2) @@ -207,7 +218,16 @@ class TestEnergyModelZBL(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = 1e-10, None + + +class TestEnergyModelSpinSeA(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index b9affac3aa..359e91d8c8 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -9,7 +9,6 @@ ) from deepmd.pt.model.model import ( get_model, - get_zbl_model, ) from deepmd.pt.utils import ( env, @@ -20,6 +19,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -35,35 +35,45 @@ def test( cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) shift = (torch.rand([3], dtype=dtype, device=env.DEVICE) - 0.5) * 2.0 coord_s = torch.matmul( torch.remainder(torch.matmul(coord + shift, torch.linalg.inv(cell)), 1.0), cell, ) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord_s.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-10 - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close(ret0["force"], ret1["force"], rtol=prec, atol=prec) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - ret0["virial"], ret1["virial"], rtol=prec, atol=prec - ) + for key in test_keys: + if key in ["energy", "force", "force_mag"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=prec, atol=prec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") class TestEnergyModelSeA(unittest.TestCase, TransTest): @@ -130,7 +140,15 @@ class TestEnergyModelZBL(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False - self.model = get_zbl_model(model_params).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeA(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index 36080c2bbd..a3c93cbe68 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -64,14 +64,9 @@ def _test_unused(self, model_params): coord = torch.matmul(coord, cell) atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) idx_perm = [1, 0, 4, 3, 2] - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype - ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } + result_0 = eval_model(self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype) + test_keys = ["energy", "force", "virial"] + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} # use computation graph to find all contributing tensors def get_contributing_params(y, top_level=True): diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index 095994f8ec..271b8f1082 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -2,6 +2,7 @@ import json import os import shutil +import tempfile import unittest from copy import ( deepcopy, @@ -13,59 +14,130 @@ import numpy as np import torch +from deepmd.entrypoints.test import test as dp_test from deepmd.pt.entrypoints.main import ( get_trainer, ) -from deepmd.pt.infer import ( - inference, +from deepmd.pt.utils.utils import ( + to_numpy_array, ) +from .model.test_permutation import ( + model_se_e2_a, + model_spin, +) -class TestDPTest(unittest.TestCase): - def setUp(self): - input_json = str(Path(__file__).parent / "water/se_atten.json") - with open(input_json) as f: - self.config = json.load(f) - self.config["training"]["numb_steps"] = 1 - self.config["training"]["save_freq"] = 1 - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = [ - str(Path(__file__).parent / "water/data/single") - ] - self.input_json = "test_dp_test.json" - with open(self.input_json, "w") as fp: - json.dump(self.config, fp, indent=4) - def test_dp_test(self): +class DPTest: + def test_dp_test_1_frame(self): trainer = get_trainer(deepcopy(self.config)) - trainer.run() - with torch.device("cpu"): input_dict, label_dict, _ = trainer.get_data(is_train=False) - _, _, more_loss = trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) - - tester = inference.Tester("model.pt", input_script=self.input_json) - try: - res = tester.run() - except StopIteration: - raise StopIteration("Unexpected stop iteration.(test step < total batch)") - for k, v in res.items(): - if k == "rmse" or "mae" in k or k not in more_loss: - continue - np.testing.assert_allclose( - v, more_loss[k].cpu().detach().numpy(), rtol=1e-04, atol=1e-07 + has_spin = getattr(trainer.model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if not has_spin: + input_dict.pop("spin", None) + input_dict["do_atomic_virial"] = True + result = trainer.model(**input_dict) + model = torch.jit.script(trainer.model) + tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") + torch.jit.save(model, tmp_model.name) + dp_test( + model=tmp_model.name, + system=self.config["training"]["validation_data"]["systems"][0], + datafile=None, + set_prefix="set", + numb_test=0, + rand_seed=None, + shuffle_test=False, + detail_file=self.detail_file, + atomic=False, + ) + os.unlink(tmp_model.name) + natom = input_dict["atype"].shape[1] + pred_e = np.loadtxt(self.detail_file + ".e.out", ndmin=2)[0, 1] + np.testing.assert_almost_equal( + pred_e, + to_numpy_array(result["energy"])[0][0], + ) + pred_e_peratom = np.loadtxt(self.detail_file + ".e_peratom.out", ndmin=2)[0, 1] + np.testing.assert_almost_equal(pred_e_peratom, pred_e / natom) + if not has_spin: + pred_f = np.loadtxt(self.detail_file + ".f.out", ndmin=2)[:, 3:6] + np.testing.assert_almost_equal( + pred_f, + to_numpy_array(result["force"]).reshape(-1, 3), + ) + pred_v = np.loadtxt(self.detail_file + ".v.out", ndmin=2)[:, 9:18] + np.testing.assert_almost_equal( + pred_v, + to_numpy_array(result["virial"]), + ) + pred_v_peratom = np.loadtxt(self.detail_file + ".v_peratom.out", ndmin=2)[ + :, 9:18 + ] + np.testing.assert_almost_equal(pred_v_peratom, pred_v / natom) + else: + pred_fr = np.loadtxt(self.detail_file + ".fr.out", ndmin=2)[:, 3:6] + np.testing.assert_almost_equal( + pred_fr, + to_numpy_array(result["force"]).reshape(-1, 3), + ) + pred_fm = np.loadtxt(self.detail_file + ".fm.out", ndmin=2)[:, 3:6] + np.testing.assert_almost_equal( + pred_fm, + to_numpy_array( + result["force_mag"][result["mask_mag"].bool().squeeze(-1)] + ).reshape(-1, 3), ) def tearDown(self): for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f) + if f.startswith(self.detail_file): + os.remove(f) if f in ["lcurve.out", self.input_json]: os.remove(f) if f in ["stat_files"]: shutil.rmtree(f) +class TestDPTestSeA(DPTest, unittest.TestCase): + def setUp(self): + self.detail_file = "test_dp_test_ener_detail" + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/single")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.input_json = "test_dp_test.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + + +class TestDPTestSeASpin(DPTest, unittest.TestCase): + def setUp(self): + self.detail_file = "test_dp_test_ener_spin_detail" + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "NiO/data/single")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_spin) + self.config["model"]["type_map"] = ["Ni", "O", "B"] + self.input_json = "test_dp_test.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_init_frz_model.py b/source/tests/pt/test_init_frz_model.py index d156eddc41..223b28515d 100644 --- a/source/tests/pt/test_init_frz_model.py +++ b/source/tests/pt/test_init_frz_model.py @@ -92,8 +92,10 @@ def test_dp_test(self): ).reshape(1, -1, 3) atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) - e1, f1, v1, ae1, av1 = dp1.eval(coord, cell, atype, atomic=True) - e2, f2, v2, ae2, av2 = dp2.eval(coord, cell, atype, atomic=True) + ret1 = dp1.eval(coord, cell, atype, atomic=True) + e1, f1, v1, ae1, av1 = ret1[0], ret1[1], ret1[2], ret1[3], ret1[4] + ret2 = dp2.eval(coord, cell, atype, atomic=True) + e2, f2, v2, ae2, av2 = ret2[0], ret2[1], ret2[2], ret2[3], ret2[4] np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10) np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10) diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 484d62a3ad..dddc9af219 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import json import os import unittest @@ -8,22 +7,27 @@ import torch tf.disable_eager_execution() +from copy import ( + deepcopy, +) from pathlib import ( Path, ) from deepmd.pt.loss import ( + EnergySpinLoss, EnergyStdLoss, ) from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, ) -from deepmd.tf.common import ( - expand_sys_str, -) from deepmd.tf.loss.ener import ( + EnerSpinLoss, EnerStdLoss, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .model.test_embedding_net import ( get_single_batch, @@ -35,28 +39,17 @@ CUR_DIR = os.path.dirname(__file__) -def get_batch(): - with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: - content = fin.read() - config = json.loads(content) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - config["training"]["training_data"]["systems"] = data_file - config["training"]["validation_data"]["systems"] = data_file - model_config = config["model"] - rcut = model_config["descriptor"]["rcut"] - # self.rcut_smth = model_config['descriptor']['rcut_smth'] - sel = model_config["descriptor"]["sel"] - systems = config["training"]["validation_data"]["systems"] - if isinstance(systems, str): - systems = expand_sys_str(systems) - dataset = DeepmdDataSetForLoader(systems[0], model_config["type_map"]) - dataset.add_data_requirement(energy_data_requirement) +def get_batch(system, type_map, data_requirement): + dataset = DeepmdDataSetForLoader(system, type_map) + dataset.add_data_requirement(data_requirement) np_batch, pt_batch = get_single_batch(dataset) return np_batch, pt_batch -class TestLearningRate(unittest.TestCase): +class TestEnerStdLoss(unittest.TestCase): def setUp(self): + self.system = str(Path(__file__).parent / "water/data/data_0") + self.type_map = ["H", "O"] self.start_lr = 1.1 self.start_pref_e = 0.02 self.limit_pref_e = 1.0 @@ -66,7 +59,9 @@ def setUp(self): self.limit_pref_v = 1.0 self.cur_lr = 1.2 # data - np_batch, pt_batch = get_batch() + np_batch, pt_batch = get_batch( + self.system, self.type_map, energy_data_requirement + ) natoms = np_batch["natoms"] self.nloc = natoms[0] l_energy, l_force, l_virial = ( @@ -177,8 +172,8 @@ def test_consistency(self): self.limit_pref_v, ) my_loss, my_more_loss = mine( - self.label, self.model_pred, + self.label, self.nloc, self.cur_lr, ) @@ -192,5 +187,179 @@ def test_consistency(self): ) +class TestEnerSpinLoss(unittest.TestCase): + def setUp(self): + self.system = str(Path(__file__).parent / "NiO/data/data_0") + self.type_map = ["Ni", "O"] + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_fr = 1000.0 + self.limit_pref_fr = 1.0 + self.start_pref_fm = 1000.0 + self.limit_pref_fm = 1.0 + self.cur_lr = 1.2 + self.use_spin = [1, 0] + # data + spin_data_requirement = deepcopy(energy_data_requirement) + spin_data_requirement.append( + DataRequirementItem( + "force_mag", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) + np_batch, pt_batch = get_batch( + self.system, self.type_map, spin_data_requirement + ) + natoms = np_batch["natoms"] + self.nloc = natoms[0] + nframes = np_batch["energy"].shape[0] + l_energy, l_force_real, l_force_mag, l_virial = ( + np_batch["energy"], + np_batch["force"], + np_batch["force_mag"], + np_batch["virial"], + ) + # merged force for tf old implement + l_force_merge_tf = np.concatenate( + [ + l_force_real.reshape(nframes, self.nloc, 3), + l_force_mag.reshape(nframes, self.nloc, 3)[ + np_batch["atype"] == 0 + ].reshape(nframes, -1, 3), + ], + axis=1, + ).reshape(nframes, -1) + p_energy, p_force_real, p_force_mag, p_force_merge_tf, p_virial = ( + np.ones_like(l_energy), + np.ones_like(l_force_real), + np.ones_like(l_force_mag), + np.ones_like(l_force_merge_tf), + np.ones_like(l_virial), + ) + virt_nloc = (np_batch["atype"] == 0).sum(-1) + natoms_tf = np.concatenate([natoms, virt_nloc], axis=0) + natoms_tf[:2] += virt_nloc + nloc = natoms_tf[0] + batch_size = pt_batch["coord"].shape[0] + atom_energy = np.zeros(shape=[batch_size, nloc]) + atom_pref = np.zeros(shape=[batch_size, nloc * 3]) + self.nloc_tf = nloc + # tf + base = EnerSpinLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, + use_spin=self.use_spin, + ) + self.g = tf.Graph() + with self.g.as_default(): + t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64) + t_natoms = tf.placeholder(shape=[None], dtype=tf.int32) + t_penergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) + t_pforce = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_pvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) + t_patom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_lenergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) + t_lforce = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_lvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) + t_latom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_atom_pref = tf.placeholder(shape=[None, None], dtype=tf.float64) + find_energy = tf.constant(1.0, dtype=tf.float64) + find_force = tf.constant(1.0, dtype=tf.float64) + find_virial = tf.constant(0.0, dtype=tf.float64) + find_atom_energy = tf.constant(0.0, dtype=tf.float64) + find_atom_pref = tf.constant(0.0, dtype=tf.float64) + model_dict = { + "energy": t_penergy, + "force": t_pforce, + "virial": t_pvirial, + "atom_ener": t_patom_energy, + } + label_dict = { + "energy": t_lenergy, + "force": t_lforce, + "virial": t_lvirial, + "atom_ener": t_latom_energy, + "atom_pref": t_atom_pref, + "find_energy": find_energy, + "find_force": find_force, + "find_virial": find_virial, + "find_atom_ener": find_atom_energy, + "find_atom_pref": find_atom_pref, + } + self.base_loss_sess = base.build( + t_cur_lr, t_natoms, model_dict, label_dict, "" + ) + # torch + self.feed_dict = { + t_cur_lr: self.cur_lr, + t_natoms: natoms_tf, + t_penergy: p_energy, + t_pforce: p_force_merge_tf, + t_pvirial: p_virial.reshape(-1, 9), + t_patom_energy: atom_energy, + t_lenergy: l_energy, + t_lforce: l_force_merge_tf, + t_lvirial: l_virial.reshape(-1, 9), + t_latom_energy: atom_energy, + t_atom_pref: atom_pref, + } + self.model_pred = { + "energy": torch.from_numpy(p_energy), + "force": torch.from_numpy(p_force_real).reshape(nframes, self.nloc, 3), + "force_mag": torch.from_numpy(p_force_mag).reshape(nframes, self.nloc, 3), + "mask_mag": torch.from_numpy(np_batch["atype"] == 0).reshape( + nframes, self.nloc, 1 + ), + } + self.label = { + "energy": torch.from_numpy(l_energy), + "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), + "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3), + } + self.natoms = pt_batch["natoms"] + + def tearDown(self) -> None: + tf.reset_default_graph() + return super().tearDown() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + base_loss, base_more_loss = sess.run( + self.base_loss_sess, feed_dict=self.feed_dict + ) + mine = EnergySpinLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, + ) + my_loss, my_more_loss = mine( + self.model_pred, + self.label, + self.nloc_tf, # use tf natoms pref + self.cur_lr, + ) + my_loss = my_loss.detach().cpu() + self.assertTrue(np.allclose(base_loss, my_loss.numpy())) + for key in ["ener", "force_r", "force_m"]: + self.assertTrue( + np.allclose( + base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key] + ) + ) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 3a09f82baf..e69caad502 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -180,7 +180,9 @@ def my_merge(energy, natoms): .unsqueeze(0) .expand(energy[i][j].shape[0], -1) ) - return energy_lst, natoms_lst + energy_merge = torch.cat(energy_lst) + natoms_merge = torch.cat(natoms_lst) + return energy_merge, natoms_merge energy = self.dp_sampled["energy"] natoms = self.dp_sampled["natoms_vec"] diff --git a/source/tests/tf/test_deeppot_a.py b/source/tests/tf/test_deeppot_a.py index 9b4d64282f..f40b57c213 100644 --- a/source/tests/tf/test_deeppot_a.py +++ b/source/tests/tf/test_deeppot_a.py @@ -804,7 +804,7 @@ def test_convert_012(self): convert_pbtxt_to_pb(str(infer_path / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from 0.12 -i {old_model} -o {new_model}") dp = DeepPot(new_model) - _, _, _, _, _ = dp.eval(self.coords, self.box, self.atype, atomic=True) + _ = dp.eval(self.coords, self.box, self.atype, atomic=True) os.remove(old_model) os.remove(new_model) @@ -814,7 +814,7 @@ def test_convert(self): convert_pbtxt_to_pb(str(infer_path / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from -i {old_model} -o {new_model}") dp = DeepPot(new_model) - _, _, _, _, _ = dp.eval(self.coords, self.box, self.atype, atomic=True) + _ = dp.eval(self.coords, self.box, self.atype, atomic=True) os.remove(old_model) os.remove(new_model)