diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index 50f3529c8e..60a4486d52 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -32,10 +32,6 @@ phys2inter, to_face_distance, ) -from .spin import ( - BaseSpin, - Spin, -) __all__ = [ "EnvMat", @@ -63,6 +59,4 @@ "to_face_distance", "AtomExcludeMask", "PairExcludeMask", - "BaseSpin", - "Spin", ] diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index f642d34d61..96989974f2 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -406,6 +406,7 @@ def eval_model( coords: Union[np.ndarray, torch.Tensor], cells: Optional[Union[np.ndarray, torch.Tensor]], atom_types: Union[np.ndarray, torch.Tensor, List[int]], + spins: Optional[Union[np.ndarray, torch.Tensor]] = None, atomic: bool = False, infer_batch_size: int = 2, denoise: bool = False, @@ -414,6 +415,8 @@ def eval_model( energy_out = [] atomic_energy_out = [] force_out = [] + force_real_out = [] + force_mag_out = [] virial_out = [] atomic_virial_out = [] updated_coord_out = [] @@ -426,11 +429,15 @@ def eval_model( if isinstance(coords, torch.Tensor): if cells is not None: assert isinstance(cells, torch.Tensor), err_msg + if spins is not None: + assert isinstance(spins, torch.Tensor), err_msg assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) elif isinstance(coords, np.ndarray): if cells is not None: assert isinstance(cells, np.ndarray), err_msg + if spins is not None: + assert isinstance(spins, np.ndarray), err_msg assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) atom_types = np.array(atom_types, dtype=np.int32) return_tensor = False @@ -450,6 +457,13 @@ def eval_model( coord_input = torch.tensor( coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE ) + spin_input = None + if spins is not None: + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) box_input = None if cells is None: @@ -465,9 +479,19 @@ def eval_model( batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] batch_box = None + batch_spin = None + if spin_input is not None: + batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] if pbc: batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_output = model(batch_coord, batch_atype, box=batch_box) + input_dict = { + "coord": batch_coord, + "atype": batch_atype, + "box": batch_box, + } + if getattr(model, "__USE_SPIN_INPUT__", False): + input_dict["spin"] = batch_spin + batch_output = model(**input_dict) if isinstance(batch_output, tuple): batch_output = batch_output[0] if not return_tensor: @@ -479,6 +503,10 @@ def eval_model( ) if "force" in batch_output: force_out.append(batch_output["force"].detach().cpu().numpy()) + if "force_real" in batch_output: + force_real_out.append(batch_output["force_real"].detach().cpu().numpy()) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) if "virial" in batch_output: virial_out.append(batch_output["virial"].detach().cpu().numpy()) if "atom_virial" in batch_output: @@ -498,6 +526,10 @@ def eval_model( atomic_energy_out.append(batch_output["atom_energy"]) if "force" in batch_output: force_out.append(batch_output["force"]) + if "force_real" in batch_output: + force_real_out.append(batch_output["force_real"]) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"]) if "virial" in batch_output: virial_out.append(batch_output["virial"]) if "atom_virial" in batch_output: @@ -518,6 +550,16 @@ def eval_model( force_out = ( np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) ) + force_real_out = ( + np.concatenate(force_real_out) + if force_real_out + else np.zeros([nframes, natoms, 3]) + ) + force_mag_out = ( + np.concatenate(force_mag_out) + if force_mag_out + else np.zeros([nframes, natoms, 3]) + ) virial_out = ( np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) ) @@ -552,6 +594,20 @@ def eval_model( [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE ) ) + force_real_out = ( + torch.cat(force_real_out) + if force_real_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + force_mag_out = ( + torch.cat(force_mag_out) + if force_mag_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) virial_out = ( torch.cat(virial_out) if virial_out @@ -571,13 +627,16 @@ def eval_model( if denoise: return updated_coord_out, logits_out else: - if not atomic: - return energy_out, force_out, virial_out + results_dict = { + "energy": energy_out, + "virial": virial_out, + } + if not getattr(model, "__USE_SPIN_INPUT__", False): + results_dict["force"] = force_out else: - return ( - energy_out, - force_out, - virial_out, - atomic_energy_out, - atomic_virial_out, - ) + results_dict["force_real"] = force_real_out + results_dict["force_mag"] = force_mag_out + if atomic: + results_dict["atom_energy"] = atomic_energy_out + results_dict["atom_virial"] = atomic_virial_out + return results_dict diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index a368a53af3..ef0a4d9c1e 100644 --- a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -128,9 +128,9 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False): if self.has_fm and "force_mag" in model_pred and "force_mag" in label: nframes = model_pred["force_mag"].shape[0] - atmoic_mask = model_pred["atmoic_mask"].expand([-1, -1, 3]) - label_force_mag = label["force_mag"][atmoic_mask].view(nframes, -1, 3) - model_pred_force_mag = model_pred["force_mag"][atmoic_mask].view( + atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3]) + label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3) + model_pred_force_mag = model_pred["force_mag"][atomic_mask].view( nframes, -1, 3 ) if not self.use_l1_all: diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 16659e444d..edbe4a980c 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -189,6 +189,20 @@ def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" raise NotImplementedError + def get_emask(self, nlist: torch.Tensor, atype: torch.Tensor) -> torch.Tensor: + """ + Compute the pair-wise type mask for given nlist and atype, + with shape same as nlist. + 1 for include and 0 for exclude. + """ + if hasattr(self, "emask"): + exclude_mask = self.emask(nlist, atype) + else: + exclude_mask = torch.ones_like( + nlist, dtype=torch.int32, device=nlist.device + ) + return exclude_mask + def share_params(self, base_class, shared_level, resume=False): assert ( self.__class__ == base_class.__class__ diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 0ad1f86b8a..a3bad786ac 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -208,6 +208,7 @@ def serialize(self) -> dict: "embeddings": obj.filter_layers.serialize(), "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, "@variables": { "davg": obj["davg"].detach().cpu().numpy(), "dstd": obj["dstd"].detach().cpu().numpy(), diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 677190471a..2da2302e27 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -21,17 +21,13 @@ from deepmd.pt.model.task import ( Fitting, ) -from deepmd.pt.utils.spin import ( +from deepmd.utils.spin import ( Spin, ) from .dp_model import ( DPModel, ) -from .dp_spin_model import ( - SpinEnergyModel, - SpinModel, -) from .dp_zbl_model import ( DPZBLModel, ) @@ -47,6 +43,10 @@ from .model import ( BaseModel, ) +from .spin_model import ( + SpinEnergyModel, + SpinModel, +) def get_zbl_model(model_params): diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 4683f62466..47169d8b6e 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -82,13 +82,14 @@ def forward_lower( model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] if self.do_grad_r("energy"): - model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] - model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/dp_spin_model.py b/deepmd/pt/model/model/spin_model.py similarity index 58% rename from deepmd/pt/model/model/dp_spin_model.py rename to deepmd/pt/model/model/spin_model.py index 7b3584aa4a..9757c21692 100644 --- a/deepmd/pt/model/model/dp_spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -8,10 +8,18 @@ from deepmd.pt.utils.utils import ( dict_to_device, + to_torch_tensor, ) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.spin import ( + Spin, +) + +from .dp_model import ( + DPModel, +) class SpinModel(torch.nn.Module): @@ -22,34 +30,110 @@ class SpinModel(torch.nn.Module): def __init__( self, backbone_model, - spin, + spin: Spin, ): super().__init__() self.backbone_model = backbone_model self.spin = spin + self.ntypes_real = self.spin.ntypes_real + self.virtual_scale_mask = to_torch_tensor(self.spin.get_virtual_scale_mask()) + self.spin_mask = to_torch_tensor(self.spin.get_spin_mask()) def process_spin_input(self, coord, atype, spin): """Generate virtual coordinates and types, concat into the input.""" - nframes, natom = coord.shape[:-1] - atype_spin = torch.concat([atype, atype + self.spin.ntypes_real], dim=-1) - virtual_scale_mask = self.spin.get_virtual_scale_mask() - virtual_coord = coord + spin * virtual_scale_mask[atype].reshape( - [nframes, natom, 1] + nframes, nloc = coord.shape[:-1] + atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) + virtual_coord = coord + spin * self.virtual_scale_mask[atype].reshape( + [nframes, nloc, 1] ) coord_spin = torch.concat([coord, virtual_coord], dim=-2) return coord_spin, atype_spin - def process_spin_output(self, atype, force): - """Split the output gradient of both real and virtual atoms, and scale the latter.""" - nframes, natom_double = force.shape[:2] - natom = natom_double // 2 - virtual_scale_mask = self.spin.get_virtual_scale_mask() - atmoic_mask = virtual_scale_mask[atype].reshape([nframes, natom, 1]) - force_real, force_mag = torch.split(force, [natom, natom], dim=1) - force_mag = (force_mag.view([nframes, natom, -1]) * atmoic_mask).view( - force_mag.shape + def process_spin_input_lower( + self, + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping: Optional[torch.Tensor] = None, + ): + """ + Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`. + Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. + """ + nframes, nall = extended_coord.shape[:2] + nloc = nlist.shape[1] + virtual_extended_coord = ( + extended_coord + + extended_spin + * self.virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + ) + virtual_extended_atype = extended_atype + self.ntypes_real + extended_coord_updated = self.concat_switch_virtual( + extended_coord, virtual_extended_coord, nloc + ) + extended_atype_updated = self.concat_switch_virtual( + extended_atype, virtual_extended_atype, nloc + ) + mapping_updated = None + if mapping is not None: + virtual_mapping = mapping + nloc + mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc) + # extend the nlist + nlist_updated = self.extend_nlist(extended_atype, nlist) + return ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) + + def process_spin_output(self, atype, out_tensor, virtual_scale: bool = True): + """Split the output both real and virtual atoms, and scale the latter.""" + nframes, nloc_double = out_tensor.shape[:2] + nloc = nloc_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[atype].reshape([nframes, nloc, 1]) + out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1) + out_mag = (out_mag.view([nframes, nloc, -1]) * atomic_mask).view(out_mag.shape) + return out_real, out_mag, atomic_mask > 0.0 + + def process_spin_output_lower( + self, extended_atype, extended_out_tensor, nloc: int, virtual_scale: bool = True + ): + """Split the extended output of both real and virtual atoms with switch, and scale the latter.""" + nframes, nall_double = extended_out_tensor.shape[:2] + nall = nall_double // 2 + if virtual_scale: + virtual_scale_mask = self.virtual_scale_mask + else: + virtual_scale_mask = self.spin_mask + atomic_mask = virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + extended_out_real = torch.cat( + [ + extended_out_tensor[:, :nloc], + extended_out_tensor[:, nloc + nloc : nloc + nall], + ], + dim=1, + ) + extended_out_mag = torch.cat( + [ + extended_out_tensor[:, nloc : nloc + nloc], + extended_out_tensor[:, nloc + nall :], + ], + dim=1, ) - return force_real, force_mag, atmoic_mask > 0.0 + extended_out_mag = ( + extended_out_mag.view([nframes, nall, -1]) * atomic_mask + ).view(extended_out_mag.shape) + return extended_out_real, extended_out_mag, atomic_mask > 0.0 @staticmethod def extend_nlist(extended_atype, nlist): @@ -69,75 +153,35 @@ def extend_nlist(extended_atype, nlist): extended_nlist = torch.cat( [extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2 ) + # update the index for switch + first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall) + second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc)) + extended_nlist[first_part_index] += nloc + extended_nlist[second_part_index] -= nall - nloc return extended_nlist @staticmethod - def extend_mapping(mapping, nloc: int): - return torch.cat([mapping, mapping + nloc], dim=-1) - - @staticmethod - def switch_virtual_loc(extended_tensor, nloc: int): - """ - Switch the virtual atoms of nloc ones from [nall: nall+nloc] to [nloc: nloc+nloc], - to assure the atom types of first nloc * 2 atoms in nall * 2 to be right. - """ - nframes, nall_double = extended_tensor.shape[:2] - nall = nall_double // 2 - swithed_tensor = torch.zeros_like(extended_tensor) - swithed_tensor[:, :nloc] = extended_tensor[:, :nloc] - swithed_tensor[:, nloc : nloc + nloc] = extended_tensor[:, nall : nall + nloc] - swithed_tensor[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:nall] - swithed_tensor[:, nloc + nall :] = extended_tensor[:, nloc + nall :] - return swithed_tensor - - @staticmethod - def switch_nlist(nlist_updated, nall: int): - nframes, nloc_double = nlist_updated.shape[:2] - nloc = nloc_double // 2 - first_part_index = (nloc <= nlist_updated) & (nlist_updated < nall) - second_part_index = (nall <= nlist_updated) & (nlist_updated < (nall + nloc)) - nlist_updated[first_part_index] += nloc - nlist_updated[second_part_index] -= nall - nloc - return nlist_updated - - def extend_switch_input( - self, - extended_coord, - extended_atype, - extended_spin, - nlist, - mapping: Optional[torch.Tensor] = None, - ): + def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int): """ - Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`. - Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: + Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms. - [:, :nloc]: original nloc real atoms. - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. - [:, nloc + nloc: nloc + nall]: ghost real atoms. - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. """ - nframes, nall = extended_coord.shape[:2] - nloc = nlist.shape[1] - # add spin but ignore the index switch - extended_coord_updated, extended_atype_updated = self.process_spin_input( - extended_coord, extended_atype, extended_spin - ) - # extend the nlist and mapping but ignore the index switch - nlist_updated = self.extend_nlist(extended_atype, nlist) - mapping_updated = None - if mapping is not None: - mapping_updated = self.extend_mapping(mapping, nloc) - # process the index switch - extended_coord_updated = self.switch_virtual_loc(extended_coord_updated, nloc) - extended_atype_updated = self.switch_virtual_loc(extended_atype_updated, nloc) - mapping_updated = self.switch_virtual_loc(mapping_updated, nloc) - nlist_updated = self.switch_nlist(nlist_updated, nall) - return ( - extended_coord_updated, - extended_atype_updated, - nlist_updated, - mapping_updated, + nframes, nall = extended_tensor.shape[:2] + extended_atype_updated = torch.zeros( + [nframes, nall * 2, *extended_tensor.shape[2:]], + dtype=extended_tensor.dtype, + device=extended_tensor.device, ) + extended_atype_updated[:, :nloc] = extended_tensor[:, :nloc] + extended_atype_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[ + :, :nloc + ] + extended_atype_updated[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:] + extended_atype_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] + return extended_atype_updated def __getattr__(self, name): """Get attribute from the wrapped model.""" @@ -201,6 +245,7 @@ def forward_common( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: + nframes, nloc = coord.shape[:2] coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) model_ret = self.backbone_model.forward_common( coord_updated, @@ -210,22 +255,33 @@ def forward_common( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - if self.fitting_net is not None: - var_name = self.fitting_net.var_name - if self.do_grad_r(var_name): + if self.backbone_model.fitting_net is not None: + var_name = self.backbone_model.fitting_net.var_name + model_ret[f"{var_name}"] = torch.split( + model_ret[f"{var_name}"], [nloc, nloc], dim=1 + )[0] + if self.backbone_model.do_grad_r(var_name): force_all = model_ret[f"{var_name}_derv_r"] ( - model_ret[f"{var_name}_derv_r_real"], + model_ret[f"{var_name}_derv_r"], model_ret[f"{var_name}_derv_r_mag"], - model_ret["atmoic_mask"], + model_ret["mask_mag"], ) = self.process_spin_output(atype, force_all) else: force_all = model_ret["dforce"] ( model_ret["dforce_real"], model_ret["dforce_mag"], - model_ret["atmoic_mask"], + model_ret["mask_mag"], ) = self.process_spin_output(atype, force_all) + if self.backbone_model.do_grad_c(var_name) and do_atomic_virial: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output( + atype, model_ret[f"{var_name}_derv_c"], virtual_scale=False + ) return model_ret def forward_common_lower( @@ -239,12 +295,13 @@ def forward_common_lower( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): + nframes, nloc = nlist.shape[:2] ( extended_coord_updated, extended_atype_updated, nlist_updated, mapping_updated, - ) = self.extend_switch_input( + ) = self.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) model_ret = self.backbone_model.forward_common_lower( @@ -256,24 +313,53 @@ def forward_common_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - if self.fitting_net is not None: - var_name = self.fitting_net.var_name - if self.do_grad_r(var_name): + if self.backbone_model.fitting_net is not None: + var_name = self.backbone_model.fitting_net.var_name + model_ret[f"{var_name}"] = torch.split( + model_ret[f"{var_name}"], [nloc, nloc], dim=1 + )[0] + if self.backbone_model.do_grad_r(var_name): force_all = model_ret[f"{var_name}_derv_r"] ( - model_ret[f"{var_name}_derv_r_real"], + model_ret[f"{var_name}_derv_r"], model_ret[f"{var_name}_derv_r_mag"], - model_ret["atmoic_mask"], - ) = self.process_spin_output(extended_atype, force_all) + model_ret["mask_mag"], + ) = self.process_spin_output_lower(extended_atype, force_all, nloc) else: force_all = model_ret["dforce"] ( model_ret["dforce_real"], model_ret["dforce_mag"], - model_ret["atmoic_mask"], - ) = self.process_spin_output(extended_atype, force_all) + model_ret["mask_mag"], + ) = self.process_spin_output_lower(extended_atype, force_all, nloc) + if self.backbone_model.do_grad_c(var_name) and do_atomic_virial: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, + model_ret[f"{var_name}_derv_c"], + nloc, + virtual_scale=False, + ) return model_ret + def serialize(self) -> dict: + return { + "backbone_model": self.backbone_model.serialize(), + "spin": self.spin.serialize(), + } + + @classmethod + def deserialize(cls, data) -> "SpinModel": + backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + spin = Spin.deserialize(data["spin"]) + return cls( + backbone_model=backbone_model_obj, + spin=spin, + ) + class SpinEnergyModel(SpinModel): """A spin model for energy.""" @@ -283,7 +369,7 @@ class SpinEnergyModel(SpinModel): def __init__( self, backbone_model, - spin, + spin: Spin, ): super().__init__(backbone_model, spin) @@ -309,11 +395,11 @@ def forward( model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - model_predict["atmoic_mask"] = model_ret["atmoic_mask"] - if self.do_grad_r("energy"): - model_predict["force_real"] = model_ret["energy_derv_r_real"].squeeze(-2) + model_predict["mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): + model_predict["force_real"] = model_ret["energy_derv_r"].squeeze(-2) model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) - if self.do_grad_c("energy"): + if self.backbone_model.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) @@ -345,18 +431,19 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - if self.fitting_net is not None: + if self.backbone_model.fitting_net is not None: model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] - if self.do_grad_r("energy"): + model_predict["mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): model_predict["extended_force_real"] = model_ret[ - "energy_derv_r_real" + "energy_derv_r" ].squeeze(-2) model_predict["extended_force_mag"] = model_ret[ "energy_derv_r_mag" ].squeeze(-2) - if self.do_grad_c("energy"): + if self.backbone_model.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["extended_virial"] = model_ret[ diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 1be9018a56..309da88fde 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -148,12 +148,7 @@ def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): if stat_file_path is not None and stat_file_path.is_file(): bias_atom_e = stat_file_path.load_numpy() else: - if hasattr(self, "emask"): - type_mask = self.emask( - torch.arange(0, self.ntypes, device=env.DEVICE).unsqueeze(0) - ) - else: - type_mask = None + type_mask = self.get_emask bias_atom_e = compute_output_bias( energy, input_natoms, rcond=self.rcond, type_mask=type_mask ) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index cd2ffbc5c0..261503db97 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -321,6 +321,9 @@ def __init__( self.exclude_types = exclude_types self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + self.type_mask = self.emask( + torch.arange(0, self.ntypes, device=device).unsqueeze(0) + ) net_dim_out = self._net_out_dim() # init constants @@ -495,6 +498,14 @@ def __getitem__(self, key): else: raise KeyError(key) + @property + def get_emask(self): + """ + Compute the atom-wise type mask for each type with shape [ntypes]. + 1 for include and 0 for exclude. + """ + return self.type_mask + @abstractmethod def _net_out_dim(self): """Set the FittingNet output dim.""" diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index c69cdc11bc..08365ad7c9 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -133,14 +133,9 @@ def iter( env_mat = env_mat.view( coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4 ) - if hasattr(self.descriptor, "emask"): - exclude_mask = self.descriptor.emask(nlist, extended_atype).view( - coord.shape[0] * coord.shape[1], -1 - ) - else: - exclude_mask = torch.ones_like( - nlist, dtype=torch.int32, device=nlist.device - ).view(coord.shape[0] * coord.shape[1], -1) + exclude_mask = self.descriptor.get_emask(nlist, extended_atype).view( + coord.shape[0] * coord.shape[1], -1 + ) atype = atype.view(coord.shape[0] * coord.shape[1]) # (1, nloc) eq (ntypes, 1), so broadcast is possible # shape: (ntypes, nloc) diff --git a/deepmd/pt/utils/spin.py b/deepmd/pt/utils/spin.py deleted file mode 100644 index 895bc3b2bc..0000000000 --- a/deepmd/pt/utils/spin.py +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.dpmodel.utils import BaseSpin as DPBaseSpin -from deepmd.pt.utils.utils import ( - to_torch_tensor, -) - - -class Spin(DPBaseSpin): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.virtual_scale_mask = to_torch_tensor( - self.virtual_scale * self.use_spin - ).view([-1]) - self.spin_mask = to_torch_tensor(self.spin_mask) - - def get_virtual_scale_mask(self): - return self.virtual_scale_mask - - def get_spin_mask(self): - return self.spin_mask - - def serialize( - self, - ) -> dict: - return { - "use_spin": self.use_spin, - "virtual_scale": self.virtual_scale, - } - - @classmethod - def deserialize( - cls, - data: dict, - ) -> "Spin": - return cls(**data) diff --git a/deepmd/dpmodel/utils/spin.py b/deepmd/utils/spin.py similarity index 91% rename from deepmd/dpmodel/utils/spin.py rename to deepmd/utils/spin.py index 69da9a51d8..e1f40bc3b9 100644 --- a/deepmd/dpmodel/utils/spin.py +++ b/deepmd/utils/spin.py @@ -1,9 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy -from abc import ( - ABC, - abstractmethod, -) from typing import ( List, Tuple, @@ -13,8 +9,8 @@ import numpy as np -class BaseSpin(ABC): - """Abstract class for spin, mainly processes the spin type-related information. +class Spin: + """Class for spin, mainly processes the spin type-related information. Atom types can be split into three kinds: 1. Real types: real atom species, "Fe", "H", "O", etc. 2. Spin types: atom species with spin, as virtual atoms in input, "Fe_spin", etc. @@ -64,11 +60,17 @@ def __init__( self.virtual_scale = virtual_scale + [ 0.0 for _ in range(self.ntypes_real - self.ntypes_spin) ] + else: + raise ValueError( + f"Invalid length of virtual_scale for spin atoms" + f": Expected {self.ntypes_real} or { self.ntypes_spin} but got {len(virtual_scale)}!" + ) elif isinstance(virtual_scale, float): self.virtual_scale = [virtual_scale for _ in range(self.ntypes_real)] else: raise ValueError(f"Invalid virtual scale type: {type(virtual_scale)}") self.virtual_scale = np.array(self.virtual_scale) + self.virtual_scale_mask = (self.virtual_scale * self.use_spin).reshape([-1]) self.pair_exclude_types = [] self.init_pair_exclude_types_placeholder() self.atom_exclude_types_ps = [] @@ -168,46 +170,26 @@ def get_atom_exclude_types_placeholder(self, exclude_types=None) -> List[int]: _exclude_types = list(set(_exclude_types)) return _exclude_types - @abstractmethod - def serialize( - self, - ) -> dict: - pass - - @classmethod - @abstractmethod - def deserialize( - cls, - data: dict, - ) -> "BaseSpin": - pass - - @abstractmethod - def get_virtual_scale_mask(self): - pass - - @abstractmethod def get_spin_mask(self): - pass - - -class Spin(BaseSpin): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.virtual_scale_mask = (self.virtual_scale * self.use_spin).reshape([-1]) + """ + Return the spin mask of shape [ntypes], + with spin types being 1, and non-spin types being 0. + """ + return self.spin_mask def get_virtual_scale_mask(self): + """ + Return the virtual scale mask of shape [ntypes], + with spin types being its virtual scale, and non-spin types being 0. + """ return self.virtual_scale_mask - def get_spin_mask(self): - return self.spin_mask - def serialize( self, ) -> dict: return { - "use_spin": self.use_spin, - "virtual_scale": self.virtual_scale, + "use_spin": self.use_spin.tolist(), + "virtual_scale": self.virtual_scale.tolist(), } @classmethod diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index c83c6246f5..ac7d478b57 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -19,6 +19,7 @@ model_dpa1, model_dpa2, model_se_e2_a, + model_spin, model_zbl, ) @@ -46,6 +47,23 @@ def stretch_box(old_coord, old_box, new_box): return ncoord.reshape(old_coord.shape) +def fix_virtual(spin, coord, new_coord, atype, model, protection=1e-8): + """ + Fix the virtual atom when doing perturbations on the real atom. + The corresponding spin will be updated to assure this. + """ + if not getattr(model, "__USE_SPIN_INPUT__", False): + return spin + else: + spin_compensation = (coord - new_coord) / ( + model.spin.get_virtual_scale_mask()[atype] + protection + ).reshape(-1, 1) + spin_compensation = spin_compensation * model.spin.get_spin_mask()[ + atype + ].reshape(-1, 1) + return spin + spin_compensation + + class ForceTest: def test( self, @@ -57,34 +75,71 @@ def test( cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") coord = torch.rand([natoms, 3], dtype=dtype, device="cpu") coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device="cpu") atype = torch.IntTensor([0, 0, 0, 1, 1]) # assumes input to be numpy tensor coord = coord.numpy() + spin = spin.numpy() + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] + + def np_infer_coord( + new_coord, + ): + result = eval_model( + self.model, + torch.tensor(new_coord, device=env.DEVICE).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=torch.tensor( + fix_virtual(spin, coord, new_coord, atype, self.model), + device=env.DEVICE, + ).unsqueeze(0), + ) + # detach + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys + } + return ret - def np_infer( - coord, + def np_infer_spin( + spin, ): - e0, f0, v0 = eval_model( + result = eval_model( self.model, torch.tensor(coord, device=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, + spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) + # detach ret = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), + key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys } - # detach - ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} return ret - def ff(_coord): - return np_infer(_coord)["energy"] + def ff_coord(_coord): + return np_infer_coord(_coord)["energy"] + + def ff_spin(_spin): + return np_infer_spin(_spin)["energy"] - fdf = -finite_difference(ff, coord, delta=delta).squeeze() - rff = np_infer(coord)["force"] - np.testing.assert_almost_equal(fdf, rff, decimal=places) + if not test_spin: + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + else: + # real force + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force_real"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + # magnetic force + fdf = -finite_difference(ff_spin, spin, delta=delta).squeeze() + rff = np_infer_spin(spin)["force_mag"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) class VirialTest: @@ -98,29 +153,37 @@ def test( cell = (cell) + 5.0 * torch.eye(3, device="cpu") coord = torch.rand([natoms, 3], dtype=dtype, device="cpu") coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device="cpu") atype = torch.IntTensor([0, 0, 0, 1, 1]) # assumes input to be numpy tensor coord = coord.numpy() cell = cell.numpy() + spin = spin.numpy() + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] def np_infer( new_cell, ): - e0, f0, v0 = eval_model( + result = eval_model( self.model, torch.tensor( stretch_box(coord, cell, new_cell), device="cpu" ).unsqueeze(0), torch.tensor(new_cell, device="cpu").unsqueeze(0), atype, + spins=torch.tensor( + stretch_box(spin, cell, new_cell), device="cpu" + ).unsqueeze(0), ) + # detach ret = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), + key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys } # detach - ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} return ret def ff(bb): @@ -203,3 +266,19 @@ def setUp(self): model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeAForce(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + + +# class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest): +# def setUp(self): +# model_params = copy.deepcopy(model_spin) +# self.type_split = False +# self.test_spin = True +# self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py new file mode 100644 index 0000000000..a2e45f39bb --- /dev/null +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -0,0 +1,329 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.model.model import ( + SpinEnergyModel, + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from .test_permutation import ( + model_spin, +) + +dtype = torch.float64 + + +def reduce_tensor(extended_tensor, mapping, nloc: int): + nframes, nall = extended_tensor.shape[:2] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + # nf x nloc x (*ext_dims) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +class SpinTest: + def setUp(self): + self.prec = 1e-10 + natoms = 5 + self.cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0) + self.coord = 3.0 * torch.rand( + [natoms, 3], dtype=dtype, device=env.DEVICE + ).unsqueeze(0) + self.spin = 0.5 * torch.rand( + [natoms, 3], dtype=dtype, device=env.DEVICE + ).unsqueeze(0) + self.atype = torch.tensor( + [0, 0, 0, 1, 1], dtype=torch.int64, device=env.DEVICE + ).unsqueeze(0) + + self.expected_mask = torch.tensor( + [ + [True], + [True], + [True], + [False], + [False], + ], + dtype=torch.bool, + device=env.DEVICE, + ).unsqueeze(0) + self.expected_atype_with_spin = torch.tensor( + [0, 0, 0, 1, 1, 3, 3, 3, 4, 4], dtype=torch.int64, device=env.DEVICE + ).unsqueeze(0) + self.expected_nloc_spin_index = ( + torch.arange(natoms, natoms * 2, dtype=torch.int64, device=env.DEVICE) + .unsqueeze(0) + .unsqueeze(-1) + ) + + def test_output_shape( + self, + ): + result = self.model( + self.coord, + self.atype, + self.spin, + self.cell, + ) + # check magnetic mask + torch.testing.assert_close(result["mask_mag"], self.expected_mask) + # check output shape to assure split + nframes, nloc = self.coord.shape[:2] + torch.testing.assert_close(result["energy"].shape, [nframes, 1]) + torch.testing.assert_close(result["atom_energy"].shape, [nframes, nloc, 1]) + torch.testing.assert_close(result["force_real"].shape, [nframes, nloc, 3]) + torch.testing.assert_close(result["force_mag"].shape, [nframes, nloc, 3]) + torch.testing.assert_close(result["virial"].shape, [nframes, 9]) + + def test_input_output_process(self): + nframes, nloc = self.coord.shape[:2] + self.real_ntypes = self.model.spin.get_ntypes_real() + # 1. test forward input process + coord_updated, atype_updated = self.model.process_spin_input( + self.coord, self.atype, self.spin + ) + # compare atypes of real and virtual atoms + torch.testing.assert_close(atype_updated, self.expected_atype_with_spin) + # compare coords of real and virtual atoms + torch.testing.assert_close(coord_updated.shape, [nframes, nloc * 2, 3]) + torch.testing.assert_close(coord_updated[:, :nloc], self.coord) + virtual_scale = torch.tensor( + self.model.spin.get_virtual_scale_mask()[self.atype.cpu()], + dtype=dtype, + device=env.DEVICE, + ) + virtual_coord = self.coord + self.spin * virtual_scale.unsqueeze(-1) + torch.testing.assert_close(coord_updated[:, nloc:], virtual_coord) + + # 2. test forward output process + model_ret = self.model.backbone_model.forward_common( + coord_updated, + atype_updated, + self.cell, + do_atomic_virial=True, + ) + if self.model.do_grad_r("energy"): + force_all = model_ret["energy_derv_r"].squeeze(-2) + force_real, force_mag, _ = self.model.process_spin_output( + self.atype, force_all + ) + torch.testing.assert_close(force_real, force_all[:, :nloc]) + torch.testing.assert_close( + force_mag, force_all[:, nloc:] * virtual_scale.unsqueeze(-1) + ) + if self.model.do_grad_c("energy"): + atom_virial_all = model_ret["energy_derv_c"].squeeze(-2) + atom_virial_real, atom_virial_mag, _ = self.model.process_spin_output( + self.atype, atom_virial_all, virtual_scale=False + ) + torch.testing.assert_close(atom_virial_real, atom_virial_all[:, :nloc]) + torch.testing.assert_close(atom_virial_mag, atom_virial_all[:, nloc:]) + + # 3. test forward_lower input process + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=self.model.mixed_types(), + box=self.cell, + ) + nall = extended_coord.shape[1] + nnei = nlist.shape[-1] + extended_spin = torch.gather( + self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 + ) + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + ) = self.model.process_spin_input_lower( + extended_coord, extended_atype, extended_spin, nlist, mapping=mapping + ) + # compare atypes of real and virtual atoms + # Note that the real and virtual atoms corresponding to the local ones are switch to the first nloc * 2 atoms + torch.testing.assert_close(extended_atype_updated.shape, [nframes, nall * 2]) + torch.testing.assert_close( + extended_atype_updated[:, :nloc], extended_atype[:, :nloc] + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc : nloc + nloc], + extended_atype[:, :nloc] + self.real_ntypes, + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc + nloc : nloc + nall], + extended_atype[:, nloc:nall], + ) + torch.testing.assert_close( + extended_atype_updated[:, nloc + nall :], + extended_atype[:, nloc:nall] + self.real_ntypes, + ) + virtual_scale = torch.tensor( + self.model.spin.get_virtual_scale_mask()[extended_atype.cpu()], + dtype=dtype, + device=env.DEVICE, + ) + # compare coords of real and virtual atoms + virtual_coord = extended_coord + extended_spin * virtual_scale.unsqueeze(-1) + torch.testing.assert_close(extended_coord_updated.shape, [nframes, nall * 2, 3]) + torch.testing.assert_close( + extended_coord_updated[:, :nloc], extended_coord[:, :nloc] + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc : nloc + nloc], virtual_coord[:, :nloc] + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc + nloc : nloc + nall], + extended_coord[:, nloc:nall], + ) + torch.testing.assert_close( + extended_coord_updated[:, nloc + nall :], virtual_coord[:, nloc:nall] + ) + + # compare mapping + torch.testing.assert_close(mapping_updated.shape, [nframes, nall * 2]) + torch.testing.assert_close(mapping_updated[:, :nloc], mapping[:, :nloc]) + torch.testing.assert_close( + mapping_updated[:, nloc : nloc + nloc], mapping[:, :nloc] + nloc + ) + torch.testing.assert_close( + mapping_updated[:, nloc + nloc : nloc + nall], mapping[:, nloc:nall] + ) + torch.testing.assert_close( + mapping_updated[:, nloc + nall :], mapping[:, nloc:nall] + nloc + ) + + # compare nlist + torch.testing.assert_close( + nlist_updated.shape, [nframes, nloc * 2, nnei * 2 + 1] + ) + # self spin + torch.testing.assert_close( + nlist_updated[:, :nloc, :1], self.expected_nloc_spin_index + ) + # real and virtual neighbors + loc_atoms_mask = (nlist < nloc) & (nlist != -1) + ghost_atoms_mask = nlist >= nloc + real_neighbors = nlist.clone() + real_neighbors[ghost_atoms_mask] += nloc + torch.testing.assert_close( + nlist_updated[:, :nloc, 1 : 1 + nnei], real_neighbors + ) + virtual_neighbors = nlist.clone() + virtual_neighbors[loc_atoms_mask] += nloc + virtual_neighbors[ghost_atoms_mask] += nall + torch.testing.assert_close( + nlist_updated[:, :nloc, 1 + nnei :], virtual_neighbors + ) + + # 4. test forward_lower output process + model_ret = self.model.backbone_model.forward_common_lower( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping=mapping_updated, + do_atomic_virial=True, + ) + if self.model.do_grad_r("energy"): + force_all = model_ret["energy_derv_r"].squeeze(-2) + force_real, force_mag, _ = self.model.process_spin_output_lower( + extended_atype, force_all, nloc + ) + force_all_switched = torch.zeros_like(force_all) + force_all_switched[:, :nloc] = force_all[:, :nloc] + force_all_switched[:, nloc:nall] = force_all[:, nloc + nloc : nloc + nall] + force_all_switched[:, nall : nall + nloc] = force_all[:, nloc : nloc + nloc] + force_all_switched[:, nall + nloc :] = force_all[:, nloc + nall :] + torch.testing.assert_close(force_real, force_all_switched[:, :nall]) + torch.testing.assert_close( + force_mag, force_all_switched[:, nall:] * virtual_scale.unsqueeze(-1) + ) + if self.model.do_grad_c("energy"): + atom_virial_all = model_ret["energy_derv_c"].squeeze(-2) + atom_virial_real, atom_virial_mag, _ = self.model.process_spin_output_lower( + extended_atype, atom_virial_all, nloc, virtual_scale=False + ) + atom_virial_all_switched = torch.zeros_like(atom_virial_all) + atom_virial_all_switched[:, :nloc] = atom_virial_all[:, :nloc] + atom_virial_all_switched[:, nloc:nall] = atom_virial_all[ + :, nloc + nloc : nloc + nall + ] + atom_virial_all_switched[:, nall : nall + nloc] = atom_virial_all[ + :, nloc : nloc + nloc + ] + atom_virial_all_switched[:, nall + nloc :] = atom_virial_all[ + :, nloc + nall : + ] + torch.testing.assert_close( + atom_virial_real, atom_virial_all_switched[:, :nall] + ) + torch.testing.assert_close( + atom_virial_mag, atom_virial_all_switched[:, nall:] + ) + + def test_jit(self): + model = torch.jit.script(self.model) + + def test_self_consistency(self): + a = self.model.serialize() + model1 = SpinEnergyModel.deserialize(a) + result = model1( + self.coord, + self.atype, + self.spin, + self.cell, + ) + expected_result = self.model( + self.coord, + self.atype, + self.spin, + self.cell, + ) + for key in result: + torch.testing.assert_close( + result[key], expected_result[key], rtol=self.prec, atol=self.prec + ) + model1 = torch.jit.script(model1) + + # def test_dp_consistency(self): + + +class TestEnergyModelSpinSeA(unittest.TestCase, SpinTest): + def setUp(self): + SpinTest.setUp(self) + model_params = copy.deepcopy(model_spin) + self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py new file mode 100644 index 0000000000..a5c9101e34 --- /dev/null +++ b/source/tests/pt/model/test_forward_lower.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from .test_permutation import ( # model_dpau, + model_dpa1, + model_dpa2, + model_se_e2_a, + model_spin, + model_zbl, +) + +dtype = torch.float64 + + +def reduce_tensor(extended_tensor, mapping, nloc: int): + nframes, nall = extended_tensor.shape[:2] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + # nf x nloc x (*ext_dims) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +class ForwardLowerTest: + def test( + self, + ): + prec = self.prec + natoms = 5 + cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) + coord = 3.0 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = 0.5 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int64, device=env.DEVICE) + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] + + result_forward = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord.unsqueeze(0), + atype.unsqueeze(0), + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=self.model.mixed_types(), + box=cell.unsqueeze(0), + ) + extended_spin = torch.gather( + spin.unsqueeze(0), index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 + ) + input_dict = { + "extended_coord": extended_coord, + "extended_atype": extended_atype, + "nlist": nlist, + "mapping": mapping, + "do_atomic_virial": False, + } + if test_spin: + input_dict["extended_spin"] = extended_spin + result_forward_lower = self.model.forward_lower(**input_dict) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close( + result_forward_lower[key], result_forward[key], rtol=prec, atol=prec + ) + elif key in ["force", "force_real", "force_mag"]: + reduced_vv = reduce_tensor( + result_forward_lower[f"extended_{key}"], mapping, natoms + ) + torch.testing.assert_close( + reduced_vv, result_forward[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + result_forward_lower[key], + result_forward[key], + rtol=prec, + atol=prec, + ) + else: + raise RuntimeError(f"Unexpected test key {key}") + + +class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_se_e2_a) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest): + def setUp(self): + self.prec = 1e-10 + model_params = copy.deepcopy(model_zbl) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest): + def setUp(self): + # still need to figure out why only 1e-6 rtol and atol + self.prec = 1e-6 + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index ea8f1a5c7a..b843411191 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -22,7 +22,7 @@ "type": "se_e2_a", "sel": [46, 92, 4], "rcut_smth": 0.50, - "rcut": 6.00, + "rcut": 4.00, "neuron": [25, 50, 100], "resnet_dt": False, "axis_neuron": 16, @@ -47,7 +47,7 @@ "type": "se_e2_a", "sel": [46, 92, 4], "rcut_smth": 0.50, - "rcut": 6.00, + "rcut": 4.00, "neuron": [25, 50, 100], "resnet_dt": False, "axis_neuron": 16, @@ -61,6 +61,32 @@ "data_stat_nbatch": 20, } +model_spin = { + "type": "spin", + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 20, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140], + "_comment": " that's all", + }, +} + model_dpa2 = { "type_map": ["O", "H", "B"], "descriptor": { @@ -204,34 +230,46 @@ def test( cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) idx_perm = [1, 0, 4, 3, 2] - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord[idx_perm].unsqueeze(0), cell.unsqueeze(0), atype[idx_perm] + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord[idx_perm].unsqueeze(0), + cell.unsqueeze(0), + atype[idx_perm], + spins=spin[idx_perm].unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-10 - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - ret0["force"][idx_perm], ret1["force"], rtol=prec, atol=prec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - ret0["virial"], ret1["virial"], rtol=prec, atol=prec - ) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_real", "force_mag"]: + torch.testing.assert_close( + ret0[key][idx_perm], ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=prec, atol=prec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") class TestEnergyModelSeA(unittest.TestCase, PermutationTest): @@ -303,6 +341,14 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestEnergyModelSpinSeA(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index f769d87546..53ea88bb7b 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -17,9 +17,8 @@ from .test_permutation import ( # model_dpau, model_dpa1, model_dpa2, - model_hybrid, model_se_e2_a, - model_zbl, + model_spin, ) dtype = torch.float64 @@ -33,80 +32,102 @@ def test( natoms = 5 cell = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) coord = 2 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + spin = 2 * torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) shift = torch.tensor([4, 4, 4], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) from scipy.stats import ( special_ortho_group, ) + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) # rotate only coord and shift to the center of cell coord_rot = torch.matmul(coord, rmat) - e0, f0, v0 = eval_model( - self.model, (coord + shift).unsqueeze(0), cell.unsqueeze(0), atype + spin_rot = torch.matmul(spin, rmat) + result_0 = eval_model( + self.model, + (coord + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, (coord_rot + shift).unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + (coord_rot + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin_rot.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), - ret1["virial"].view([3, 3]), - rtol=prec, - atol=prec, - ) - + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_real", "force_mag"]: + torch.testing.assert_close( + torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul( + rmat.T, torch.matmul(ret0[key].view([3, 3]), rmat) + ), + ret1[key].view([3, 3]), + rtol=prec, + atol=prec, + ) + else: + raise RuntimeError(f"Unexpected test key {key}") # rotate coord and cell torch.manual_seed(0) cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) coord_rot = torch.matmul(coord, rmat) + spin_rot = torch.matmul(spin, rmat) cell_rot = torch.matmul(cell, rmat) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord_rot.unsqueeze(0), cell_rot.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord_rot.unsqueeze(0), + cell_rot.unsqueeze(0), + atype, + spins=spin_rot.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close( - torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), - ret1["virial"].view([3, 3]), - rtol=prec, - atol=prec, - ) + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key in ["force", "force_real", "force_mag"]: + torch.testing.assert_close( + torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul( + rmat.T, torch.matmul(ret0[key].view([3, 3]), rmat) + ), + ret1[key].view([3, 3]), + rtol=prec, + atol=prec, + ) class TestEnergyModelSeA(unittest.TestCase, RotTest): @@ -153,28 +174,11 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("hybrid not supported at the moment") -class TestEnergyModelHybrid(unittest.TestCase, RotTest): - def setUp(self): - model_params = copy.deepcopy(model_hybrid) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -@unittest.skip("hybrid not supported at the moment") -class TestForceModelHybrid(unittest.TestCase, RotTest): - def setUp(self): - model_params = copy.deepcopy(model_hybrid) - model_params["fitting_net"]["type"] = "direct_force_ener" - self.type_split = True - self.test_virial = False - self.model = get_model(model_params).to(env.DEVICE) - - -class TestEnergyModelZBL(unittest.TestCase, RotTest): +class TestEnergyModelSpinSeA(unittest.TestCase, RotTest): def setUp(self): - model_params = copy.deepcopy(model_zbl) + model_params = copy.deepcopy(model_spin) self.type_split = False + self.test_spin = True self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index d4d203bf51..be6e3368be 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -19,6 +19,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -58,6 +59,7 @@ def test( ) coord1 = torch.matmul(coord1, cell) coord = torch.concat([coord0, coord1], dim=0) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord0 = torch.clone(coord) coord1 = torch.clone(coord) @@ -67,52 +69,63 @@ def test( coord3 = torch.clone(coord) coord3[1][0] += epsilon coord3[2][1] += epsilon - - e0, f0, v0 = eval_model( - self.model, coord0.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] + + result_0 = eval_model( + self.model, + coord0.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord1.unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord1.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } - e2, f2, v2 = eval_model( - self.model, coord2.unsqueeze(0), cell.unsqueeze(0), atype + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} + result_2 = eval_model( + self.model, + coord2.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret2 = { - "energy": e2.squeeze(0), - "force": f2.squeeze(0), - "virial": v2.squeeze(0), - } - e3, f3, v3 = eval_model( - self.model, coord3.unsqueeze(0), cell.unsqueeze(0), atype + ret2 = {key: result_2[key].squeeze(0) for key in test_keys} + result_3 = eval_model( + self.model, + coord3.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret3 = { - "energy": e3.squeeze(0), - "force": f3.squeeze(0), - "virial": v3.squeeze(0), - } + ret3 = {key: result_3[key].squeeze(0) for key in test_keys} def compare(ret0, ret1): - torch.testing.assert_close( - ret0["energy"], ret1["energy"], rtol=rprec, atol=aprec - ) - # plus 1. to avoid the divided-by-zero issue - torch.testing.assert_close( - 1.0 + ret0["force"], 1.0 + ret1["force"], rtol=rprec, atol=aprec - ) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - 1.0 + ret0["virial"], 1.0 + ret1["virial"], rtol=rprec, atol=aprec - ) + for key in test_keys: + if key in ["energy"]: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=rprec, atol=aprec + ) + elif key in ["force", "force_real", "force_mag"]: + # plus 1. to avoid the divided-by-zero issue + torch.testing.assert_close( + 1.0 + ret0[key], 1.0 + ret1[key], rtol=rprec, atol=aprec + ) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + 1.0 + ret0[key], 1.0 + ret1[key], rtol=rprec, atol=aprec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") compare(ret0, ret1) compare(ret1, ret2) @@ -211,6 +224,15 @@ def setUp(self): self.epsilon, self.aprec = None, None +class TestEnergyModelSpinSeA(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index c630112854..656f9f7380 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -19,6 +19,7 @@ model_dpa2, model_hybrid, model_se_e2_a, + model_spin, model_zbl, ) @@ -34,35 +35,45 @@ def test( cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE) coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) shift = (torch.rand([3], dtype=dtype, device=env.DEVICE) - 0.5) * 2.0 coord_s = torch.matmul( torch.remainder(torch.matmul(coord + shift, torch.linalg.inv(cell)), 1.0), cell, ) - e0, f0, v0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force_real", "force_mag", "virial"] + result_0 = eval_model( + self.model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret0 = { - "energy": e0.squeeze(0), - "force": f0.squeeze(0), - "virial": v0.squeeze(0), - } - e1, f1, v1 = eval_model( - self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + result_1 = eval_model( + self.model, + coord_s.unsqueeze(0), + cell.unsqueeze(0), + atype, + spins=spin.unsqueeze(0), ) - ret1 = { - "energy": e1.squeeze(0), - "force": f1.squeeze(0), - "virial": v1.squeeze(0), - } + ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-10 - torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) - torch.testing.assert_close(ret0["force"], ret1["force"], rtol=prec, atol=prec) - if not hasattr(self, "test_virial") or self.test_virial: - torch.testing.assert_close( - ret0["virial"], ret1["virial"], rtol=prec, atol=prec - ) + for key in test_keys: + if key in ["energy", "force", "force_real", "force_mag"]: + torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) + elif key == "virial": + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0[key], ret1[key], rtol=prec, atol=prec + ) + else: + raise RuntimeError(f"Unexpected test key {key}") class TestEnergyModelSeA(unittest.TestCase, TransTest): @@ -134,5 +145,13 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestEnergyModelSpinSeA(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) + + if __name__ == "__main__": unittest.main()