diff --git a/deepmd/calculator.py b/deepmd/calculator.py index c5f742bbec..6ac676dcf6 100644 --- a/deepmd/calculator.py +++ b/deepmd/calculator.py @@ -130,7 +130,7 @@ def calculate( cell = None symbols = self.atoms.get_chemical_symbols() atype = [self.type_dict[k] for k in symbols] - e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype) + e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3] self.results["energy"] = e[0][0] # see https://gitlab.com/ase/ase/-/merge_requests/2485 self.results["free_energy"] = e[0][0] diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index ce176f5f45..91fa0ac2ac 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -383,6 +383,9 @@ def _get_output_shape(self, odef, nframes, natoms): # Something wrong here? # return [nframes, *shape, natoms, 1] return [nframes, natoms, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_R_DERV_R: + # hessian + return [nframes, 3 * natoms, 3 * natoms] else: raise RuntimeError("unknown category") diff --git a/deepmd/driver.py b/deepmd/driver.py index 30916259aa..8d17968376 100644 --- a/deepmd/driver.py +++ b/deepmd/driver.py @@ -67,7 +67,7 @@ def label(self, data: dict) -> dict: cell = data["cells"].reshape((nframes, 9)) else: cell = None - e, f, v = self.dp.eval(coord, cell, atype) + e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3] data = data.copy() data["energies"] = e.reshape((nframes,)) data["forces"] = f.reshape((nframes, natoms, 3)) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index d9744246d7..2c5d68be0e 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -303,6 +303,8 @@ def test_ener( if dp.has_spin: data.add("spin", 3, atomic=True, must=True, high_prec=False) data.add("force_mag", 3, atomic=True, must=False, high_prec=False) + if dp.has_hessian: + data.add("hessian", 1, atomic=True, must=True, high_prec=False) test_data = data.get_test() mixed_type = data.mixed_type @@ -352,6 +354,9 @@ def test_ener( energy = energy.reshape([numb_test, 1]) force = force.reshape([numb_test, -1]) virial = virial.reshape([numb_test, 9]) + if dp.has_hessian: + hessian = ret[3] + hessian = hessian.reshape([numb_test, -1]) if has_atom_ener: ae = ret[3] av = ret[4] @@ -415,6 +420,10 @@ def test_ener( rmse_ea = rmse_e / natoms mae_va = mae_v / natoms rmse_va = rmse_v / natoms + if dp.has_hessian: + diff_h = hessian - test_data["hessian"][:numb_test] + mae_h = mae(diff_h) + rmse_h = rmse(diff_h) if has_atom_ener: diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1]) mae_ae = mae(diff_ae) @@ -447,6 +456,9 @@ def test_ener( if has_atom_ener: log.info(f"Atomic ener MAE : {mae_ae:e} eV") log.info(f"Atomic ener RMSE : {rmse_ae:e} eV") + if dp.has_hessian: + log.info(f"Hessian MAE : {mae_h:e} eV/A^2") + log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2") if detail_file is not None: detail_path = Path(detail_file) @@ -530,8 +542,24 @@ def test_ener( "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", append=append_detail, ) + if dp.has_hessian: + data_h = test_data["hessian"][:numb_test].reshape(-1, 1) + pred_h = hessian.reshape(-1, 1) + h = np.concatenate( + ( + data_h, + pred_h, + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".h.out"), + h, + header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)", + append=append_detail, + ) if not out_put_spin: - return { + dict_to_return = { "mae_e": (mae_e, energy.size), "mae_ea": 
(mae_ea, energy.size), "mae_f": (mae_f, force.size), @@ -544,7 +572,7 @@ def test_ener( "rmse_va": (rmse_va, virial.size), } else: - return { + dict_to_return = { "mae_e": (mae_e, energy.size), "mae_ea": (mae_ea, energy.size), "mae_fr": (mae_fr, force_r.size), @@ -558,6 +586,10 @@ def test_ener( "rmse_v": (rmse_v, virial.size), "rmse_va": (rmse_va, virial.size), } + if dp.has_hessian: + dict_to_return["mae_h"] = (mae_h, hessian.size) + dict_to_return["rmse_h"] = (rmse_h, hessian.size) + return dict_to_return def print_ener_sys_avg(avg: dict[str, float]) -> None: @@ -584,6 +616,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None: log.info(f"Virial RMSE : {avg['rmse_v']:e} eV") log.info(f"Virial MAE/Natoms : {avg['mae_va']:e} eV") log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV") + if "rmse_h" in avg.keys(): + log.info(f"Hessian MAE : {avg['mae_h']:e} eV/A^2") + log.info(f"Hessian RMSE : {avg['rmse_h']:e} eV/A^2") def test_dos( diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 159f9bdf60..8cc05c6daa 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -77,6 +77,7 @@ class DeepEvalBackend(ABC): # old models in v1 "global_polar": "global_polar", "wfc": "wfc", + "energy_derv_r_derv_r": "hessian", } @abstractmethod @@ -276,6 +277,10 @@ def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" return False + def get_has_hessian(self): + """Check if the model has hessian.""" + return False + @abstractmethod def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" @@ -541,6 +546,11 @@ def has_spin(self) -> bool: """Check if the model has spin.""" return self.deep_eval.get_has_spin() + @property + def has_hessian(self) -> bool: + """Check if the model has hessian.""" + return self.deep_eval.get_has_hessian() + def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" return self.deep_eval.get_ntypes_spin() diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 4755bc276a..6e00a30f91 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -64,6 +64,7 @@ def output_def(self) -> ModelOutputDef: r_differentiable=True, c_differentiable=True, atomic=True, + r_hessian=True, ), ] ) @@ -99,7 +100,10 @@ def eval( aparam: Optional[np.ndarray], mixed_type: bool, **kwargs: Any, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ) -> Union[ + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + ]: pass @overload @@ -113,7 +117,10 @@ def eval( aparam: Optional[np.ndarray], mixed_type: bool, **kwargs: Any, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Union[ + tuple[np.ndarray, np.ndarray, np.ndarray], + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + ]: pass @overload @@ -179,6 +186,8 @@ def eval( atomic_virial The atomic virial of the system, in shape (nframes, natoms, 9). Only returned when atomic is True. + hessian + The Hessian matrix of the system, in shape (nframes, 3 * natoms, 3 * natoms). Returned when available. 
""" # This method has been used by: # documentation python.md @@ -239,6 +248,11 @@ def eval( force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3) mask_mag = results["mask_mag"].reshape(nframes, natoms, 1) result = (*list(result), force_mag, mask_mag) + if self.deep_eval.get_has_hessian(): + hessian = results["energy_derv_r_derv_r"].reshape( + nframes, 3 * natoms, 3 * natoms + ) + result = (*list(result), hessian) return result diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index a83964329e..acfd42b66a 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -411,6 +411,9 @@ def _get_output_shape(self, odef, nframes, natoms): elif odef.category == OutputVariableCategory.OUT: # atom_energy, atom_tensor return [nframes, natoms, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_R_DERV_R: + # hessian + return [nframes, 3 * natoms, 3 * natoms] else: raise RuntimeError("unknown category") diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 59b833d34c..dc3992dd90 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -130,7 +130,8 @@ def __init__( ] = state_dict[item].clone() state_dict = state_dict_head model = get_model(self.input_param).to(DEVICE) - model = torch.jit.script(model) + if not self.input_param.get("hessian_mode"): + model = torch.jit.script(model) self.dp = ModelWrapper(model) self.dp.load_state_dict(state_dict) elif str(self.model_path).endswith(".pth"): @@ -160,6 +161,7 @@ def __init__( self._has_spin = getattr(self.dp.model["Default"], "has_spin", False) if callable(self._has_spin): self._has_spin = self._has_spin() + self._has_hessian = self.model_def_script.get("hessian_mode", False) def get_rcut(self) -> float: """Get the cutoff radius of this model.""" @@ -234,6 +236,10 @@ def get_has_spin(self): """Check if the model has spin atom types.""" return self._has_spin + def get_has_hessian(self): + """Check if the model has hessian.""" + return self._has_hessian + def eval( self, coords: np.ndarray, @@ -339,6 +345,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: OutputVariableCategory.REDU, OutputVariableCategory.DERV_R, OutputVariableCategory.DERV_C_REDU, + OutputVariableCategory.DERV_R_DERV_R, ) ] @@ -568,6 +575,9 @@ def _get_output_shape(self, odef, nframes, natoms): # Something wrong here? 
# return [nframes, *shape, natoms, 1] return [nframes, natoms, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_R_DERV_R: + return [nframes, 3 * natoms, 3 * natoms] + # return [nframes, *odef.shape, 3 * natoms, 3 * natoms] else: raise RuntimeError("unknown category") diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index 0e3bc31057..dd0e7eaccb 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -55,6 +55,9 @@ def __init__( ] = state_dict[item].clone() state_dict = state_dict_head + model_params.pop( + "hessian_mode", None + ) # wrapper Hessian to Energy model due to JIT limit self.model_params = deepcopy(model_params) self.model = get_model(model_params).to(DEVICE) diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index cae561a8a2..1d25c1e52f 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -6,6 +6,7 @@ DOSLoss, ) from .ener import ( + EnergyHessianStdLoss, EnergyStdLoss, ) from .ener_spin import ( @@ -24,6 +25,7 @@ __all__ = [ "DOSLoss", "DenoiseLoss", + "EnergyHessianStdLoss", "EnergySpinLoss", "EnergyStdLoss", "PropertyLoss", diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 327d75c2cd..b564aa57ec 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -411,3 +411,75 @@ def label_requirement(self) -> list[DataRequirementItem]: ) ) return label_requirement + + +class EnergyHessianStdLoss(EnergyStdLoss): + def __init__( + self, + start_pref_h=0.0, + limit_pref_h=0.0, + **kwargs, + ): + r"""Enable the layer to compute loss on hessian. + + Parameters + ---------- + start_pref_h : float + The prefactor of hessian loss at the start of the training. + limit_pref_h : float + The prefactor of hessian loss at the end of the training. + **kwargs + Other keyword arguments. 
+ """ + super().__init__(**kwargs) + self.has_h = (start_pref_h != 0.0 and limit_pref_h != 0.0) or self.inference + + self.start_pref_h = start_pref_h + self.limit_pref_h = limit_pref_h + + def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): + model_pred, loss, more_loss = super().forward( + input_dict, model, label, natoms, learning_rate, mae=mae + ) + coef = learning_rate / self.starter_learning_rate + pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef + + if self.has_h and "hessian" in model_pred and "hessian" in label: + find_hessian = label.get("find_hessian", 0.0) + pref_h = pref_h * find_hessian + diff_h = label["hessian"].reshape( + -1, + ) - model_pred["hessian"].reshape( + -1, + ) + l2_hessian_loss = torch.mean(torch.square(diff_h)) + if not self.inference: + more_loss["l2_hessian_loss"] = self.display_if_exist( + l2_hessian_loss.detach(), find_hessian + ) + loss += pref_h * l2_hessian_loss + rmse_h = l2_hessian_loss.sqrt() + more_loss["rmse_h"] = self.display_if_exist(rmse_h.detach(), find_hessian) + if mae: + mae_h = torch.mean(torch.abs(diff_h)) + more_loss["mae_h"] = self.display_if_exist(mae_h.detach(), find_hessian) + + if not self.inference: + more_loss["rmse"] = torch.sqrt(loss.detach()) + return model_pred, loss, more_loss + + @property + def label_requirement(self) -> list[DataRequirementItem]: + """Add hessian label requirement needed for this loss calculation.""" + label_requirement = super().label_requirement + if self.has_h: + label_requirement.append( + DataRequirementItem( + "hessian", + ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms + atomic=True, + must=False, + high_prec=False, + ) + ) + return label_requirement diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index e89e7467d3..dc7142249a 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -21,10 +21,11 @@ def _make_env_mat( nall = coord.shape[1] mask = nlist >= 0 # nlist = nlist * mask ## this impl will contribute nans in Hessian calculation. 
- nlist = torch.where(mask, nlist, nall - 1) + nlist = torch.where(mask, nlist, nall) coord_l = coord[:, :natoms].view(bsz, -1, 1, 3) index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3) - coord_r = torch.gather(coord, 1, index) + coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1) + coord_r = torch.gather(coord_pad, 1, index) coord_r = coord_r.view(bsz, natoms, nnei, 3) diff = coord_r - coord_l length = torch.linalg.norm(diff, dim=-1, keepdim=True) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 491a524da8..37e664e82a 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -276,6 +276,8 @@ def get_standard_model(model_params): pair_exclude_types=pair_exclude_types, preset_out_bias=preset_out_bias, ) + if model_params.get("hessian_mode"): + model.enable_hessian() model.model_def_script = json.dumps(model_params_old) return model diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 9487bcc5bb..8064d3eac7 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -15,6 +15,9 @@ from .dp_model import ( DPModelCommon, ) +from .make_hessian_model import ( + make_hessian_model, +) from .make_model import ( make_model, ) @@ -33,6 +36,13 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self): + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True def translated_output_def(self): out_def_data = self.model_output_def().get_data() @@ -50,6 +60,8 @@ def translated_output_def(self): output_def["atom_virial"].squeeze(-3) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def def forward( @@ -85,6 +97,8 @@ def forward( model_predict["force"] = model_ret["dforce"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) else: model_predict = model_ret model_predict["updated_coord"] += coord diff --git a/deepmd/pt/model/model/make_hessian_model.py b/deepmd/pt/model/model/make_hessian_model.py index 4104314225..000b9abea4 100644 --- a/deepmd/pt/model/model/make_hessian_model.py +++ b/deepmd/pt/model/model/make_hessian_model.py @@ -172,11 +172,10 @@ def _cal_hessian_one_component( # fparam: Optional[torch.Tensor] = None, # nfp # aparam: Optional[torch.Tensor] = None, # (nloc x nap) wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam) - hess = torch.autograd.functional.hessian( wc, coord, - create_graph=False, + create_graph=self.training, ) return hess diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 72e84d577a..57aae3366b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -25,6 +25,7 @@ from deepmd.pt.loss import ( DenoiseLoss, DOSLoss, + EnergyHessianStdLoss, EnergySpinLoss, EnergyStdLoss, PropertyLoss, @@ -264,8 +265,22 @@ def get_lr(lr_params): else: self.opt_type, self.opt_param = get_opt_param(training_params) + # loss_param_tmp for Hessian activation + loss_param_tmp = None + if not self.multi_task: + loss_param_tmp = config["loss"] + else: + loss_param_tmp = { + model_key: 
config["loss_dict"][model_key] + for model_key in self.model_keys + } + # Model - self.model = get_model_for_wrapper(model_params, resuming=resuming) + self.model = get_model_for_wrapper( + model_params, + resuming=resuming, + _loss_params=loss_param_tmp, + ) # Loss if not self.multi_task: @@ -1210,9 +1225,17 @@ def get_additional_data_requirement(_model): return additional_data_requirement +def whether_hessian(loss_params): + loss_type = loss_params.get("type", "ener") + return loss_type == "ener" and loss_params.get("start_pref_h", 0.0) > 0.0 + + def get_loss(loss_params, start_lr, _ntypes, _model): loss_type = loss_params.get("type", "ener") - if loss_type == "ener": + if whether_hessian(loss_params): + loss_params["starter_learning_rate"] = start_lr + return EnergyHessianStdLoss(**loss_params) + elif loss_type == "ener": loss_params["starter_learning_rate"] = start_lr return EnergyStdLoss(**loss_params) elif loss_type == "dos": @@ -1257,8 +1280,14 @@ def get_single_model( return model -def get_model_for_wrapper(_model_params, resuming=False): +def get_model_for_wrapper( + _model_params, + resuming=False, + _loss_params=None, +): if "model_dict" not in _model_params: + if _loss_params is not None and whether_hessian(_loss_params): + _model_params["hessian_mode"] = True _model = get_single_model( _model_params, ) @@ -1267,6 +1296,8 @@ def get_model_for_wrapper(_model_params, resuming=False): model_keys = list(_model_params["model_dict"]) do_case_embd, case_embd_index = get_case_embd_config(_model_params) for _model_key in model_keys: + if _loss_params is not None and whether_hessian(_loss_params[_model_key]): + _model_params["model_dict"][_model_key]["hessian_mode"] = True _model[_model_key] = get_single_model( _model_params["model_dict"][_model_key], ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b57f15979..7d7fd18602 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2176,6 +2176,8 @@ def loss_ener(): doc_limit_pref_f = limit_pref("force") doc_start_pref_v = start_pref("virial", abbr="v") doc_limit_pref_v = limit_pref("virial") + doc_start_pref_h = start_pref("hessian", abbr="h") # prefactor of hessian + doc_limit_pref_h = limit_pref("hessian") doc_start_pref_ae = start_pref("atomic energy", label="atom_ener", abbr="ae") doc_limit_pref_ae = limit_pref("atomic energy") doc_start_pref_pf = start_pref( @@ -2230,6 +2232,20 @@ def loss_ener(): default=0.00, doc=doc_limit_pref_v, ), + Argument( + "start_pref_h", + [float, int], + optional=True, + default=0.00, + doc=doc_start_pref_h, + ), + Argument( + "limit_pref_h", + [float, int], + optional=True, + default=0.00, + doc=doc_limit_pref_h, + ), Argument( "start_pref_ae", [float, int], diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 3d74c72bda..d572efd321 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -660,9 +660,24 @@ def _load_data( f"({nframes}, {natoms_sel}, {ndof_}) or" f"({nframes}, {natoms}, {ndof_})" ) - data = data.reshape([nframes, natoms, -1]) - data = data[:, idx_map, :] - data = data.reshape([nframes, -1]) + if key == "hessian": + data = data.reshape(nframes, 3 * natoms, 3 * natoms) + # get idx_map for hessian + num_chunks, chunk_size = len(idx_map), 3 + idx_map_hess = np.arange(num_chunks * chunk_size) # pylint: disable=no-explicit-dtype + idx_map_hess = idx_map_hess.reshape(num_chunks, chunk_size) + idx_map_hess = idx_map_hess[idx_map] + idx_map_hess = idx_map_hess.flatten() + data = data[:, idx_map_hess, :] + data = data[:, :, idx_map_hess] + data 
= data.reshape([nframes, -1]) + ndof = ( + 3 * ndof * 3 * ndof + ) # size of hessian is 3Natoms * 3Natoms + else: + data = data.reshape([nframes, natoms, -1]) + data = data[:, idx_map, :] + data = data.reshape([nframes, -1]) data = np.reshape(data, [nframes, ndof]) except ValueError as err_message: explanation = "This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`." diff --git a/doc/data/system.md b/doc/data/system.md index b50c6fa256..f6da7b534b 100644 --- a/doc/data/system.md +++ b/doc/data/system.md @@ -22,29 +22,31 @@ The input frame properties contain the following property, the first axis of whi The labeled frame properties are listed as follows, all of which will be used for training if and only if the loss function contains such property: -| ID | Property | Raw file | Unit | Shape | Description | -| --------------------- | -------------------------------------------------------------------------------- | ------------------------- | ---- | ------------------------------------- | ----------------------------------------- | -| energy | Frame energies | energy.raw | eV | Nframes | -| force | Atomic forces | force.raw | eV/Å | Nframes \* Natoms \* 3 | -| virial | Frame virial | virial.raw | eV | Nframes \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | -| atom_ener | Atomic energies | atom_ener.raw | eV | Nframes \* Natoms | -| atom_pref | Weights of atomic forces | atom_pref.raw | 1 | Nframes \* Natoms | -| dipole | Frame dipole | dipole.raw | Any | Nframes \* 3 | -| atomic_dipole | Atomic dipole | atomic_dipole.raw | Any | Nframes \* Natoms \* 3 | -| polarizability | Frame polarizability | polarizability.raw | Any | Nframes \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | -| atomic_polarizability | Atomic polarizability | atomic_polarizability.raw | Any | Nframes \* Natoms \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | -| drdq | Partial derivative of atomic coordinates with respect to generalized coordinates | drdq.raw | 1 | Nframes \* Natoms \* 3 \* Ngen_coords | +| ID | Property | Raw file | Unit | Shape | Description | +| --------------------- | -------------------------------------------------------------------------------- | ------------------------- | ------ | ------------------------------------- | ----------------------------------------- | +| energy | Frame energies | energy.raw | eV | Nframes | +| force | Atomic forces | force.raw | eV/Å | Nframes \* Natoms \* 3 | +| virial | Frame virial | virial.raw | eV | Nframes \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | +| hessian | Frame energy Hessian matrices | hessian.raw | eV/Å^2 | Nframes \* Natoms \* 3 \* Natoms \* 3 | full Hessian matrices | +| atom_ener | Atomic energies | atom_ener.raw | eV | Nframes \* Natoms | +| atom_pref | Weights of atomic forces | atom_pref.raw | 1 | Nframes \* Natoms | +| dipole | Frame dipole | dipole.raw | Any | Nframes \* 3 | +| atomic_dipole | Atomic dipole | atomic_dipole.raw | Any | Nframes \* Natoms \* 3 | +| polarizability | Frame polarizability | polarizability.raw | Any | Nframes \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | +| atomic_polarizability | Atomic polarizability | atomic_polarizability.raw | Any | Nframes \* Natoms \* 9 | in the order `XX XY XZ YX YY YZ ZX ZY ZZ` | +| drdq | Partial derivative of atomic coordinates with respect to generalized coordinates | drdq.raw | 1 | Nframes \* Natoms \* 3 \* Ngen_coords | In general, we always use the following convention of 
units: -| Property | Unit | -| -------- | ---- | -| Time | ps | -| Length | Å | -| Energy | eV | -| Force | eV/Å | -| Virial | eV | -| Pressure | Bar | +| Property | Unit | +| -------- | ------ | +| Time | ps | +| Length | Å | +| Energy | eV | +| Force | eV/Å | +| Virial | eV | +| Hessian | eV/Å^2 | +| Pressure | Bar | ## Mixed type diff --git a/doc/model/index.rst b/doc/model/index.rst index c067ea4207..8c13d72a60 100644 --- a/doc/model/index.rst +++ b/doc/model/index.rst @@ -14,6 +14,7 @@ Model sel train-energy train-energy-spin + train-energy-hessian train-fitting-tensor train-fitting-dos train-se-e2-a-tebd diff --git a/doc/model/overall.md b/doc/model/overall.md index 7f67c6545d..cc72aa3887 100644 --- a/doc/model/overall.md +++ b/doc/model/overall.md @@ -57,6 +57,11 @@ DeePMD-kit implements the following descriptors: The fitting of the following physical properties is supported -1. [`ener`](train-energy.md): Fit the energy of the system. The force (derivative with atom positions) and the virial (derivative with the box tensor) can also be trained. +1. [`ener`](train-energy.md): Fit the energy of the system. The force (derivative with respect to atom positions), the virial (derivative with respect to the box tensor), and the Hessian (second-order derivative with respect to atom positions) can also be trained. + +:::{warning} +Due to the restrictions of TorchScript, models trained with Hessian are not JIT-compilable, so frozen models cannot output Hessians. +::: + 2. [`dipole`](train-fitting-tensor.md): The dipole moment. 3. [`polar`](train-fitting-tensor.md): The polarizability. diff --git a/doc/model/train-energy-hessian.md b/doc/model/train-energy-hessian.md new file mode 100644 index 0000000000..02b3c52052 --- /dev/null +++ b/doc/model/train-energy-hessian.md @@ -0,0 +1,54 @@ +# Fit energy Hessian {{ pytorch_icon }} + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }} +::: + +:::{warning} +A model trained with Hessian cannot be frozen. If freezing is forced, the model is treated as a standard energy model, and the frozen model can no longer output Hessian predictions. +::: + +To train a model on Hessian matrices, i.e., the second-order derivatives of energies with respect to coordinates, you only need to prepare full Hessian matrices as labels and modify the `loss` section to define the Hessian-specific settings, keeping the other sections the same as in a normal energy model's input script. + +## Energy Hessian Loss + +To train with Hessians, add the start and limit prefactors of the Hessian, i.e., {ref}`start_pref_h ` and {ref}`limit_pref_h `, to the {ref}`loss ` section in the `input.json`: + +```json + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "start_pref_h": 10, + "limit_pref_h": 1 + }, +``` + +The options {ref}`start_pref_e `, {ref}`limit_pref_e `, {ref}`start_pref_f `, {ref}`limit_pref_f `, {ref}`start_pref_v ` and {ref}`limit_pref_v ` determine the start and limit prefactors of energy, force, and virial, respectively. The calculation and definition of the Hessian loss are the same as for the other terms. + +If you do not want to train with the virial, set the virial prefactors {ref}`start_pref_v ` and {ref}`limit_pref_v ` to 0.
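Once a Hessian-enabled model is trained, its predictions can be obtained through the Python interface; since frozen models do not output Hessians, the checkpoint itself is loaded. Below is a minimal sketch: the checkpoint path, coordinates, and atom types are placeholders, and the four-element return of `DeepPot.eval` follows the behavior added in `deepmd/infer/deep_pot.py` in this patch.

```python
import numpy as np

from deepmd.infer.deep_pot import DeepPot

# Hypothetical checkpoint from Hessian training; a frozen .pth model would
# behave as a plain energy model and return only (energy, force, virial).
dp = DeepPot("model.ckpt.pt")

coord = np.random.default_rng(0).random([1, 15, 3])  # (nframes, natoms, 3), placeholder
cell = None  # non-periodic system
atype = [0, 0, 0, 2, 2, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1]  # type indices, e.g. C/H/N/O

if dp.has_hessian:
    e, f, v, h = dp.eval(coord, cell, atype)
    print(h.shape)  # (1, 3 * 15, 3 * 15), in eV/A^2
else:
    e, f, v = dp.eval(coord, cell, atype)[:3]
```

For periodic systems, pass `cells` with shape `(nframes, 9)` instead of `None`.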
+ +## Hessian format in PyTorch + +In the PyTorch backend, Hessian matrices are listed in `hessian.npy` files, and the data format may contain the following files: + +``` +type.raw +set.*/box.npy +set.*/coord.npy +set.*/energy.npy +set.*/force.npy +set.*/hessian.npy +``` + +This system contains `Nframes` frames with the same atom number `Natoms`, the total number of elements contained in all frames is `Ntypes`. Most files are the same as those in [standard formats](../data/system.md), here we only list the distinct ones: + +| ID | Property | Raw file | Unit | Shape | Description | +| ------- | ---------------- | ----------- | ------ | --------------------------------------- | ------------------------------------------------------- | +| hessian | Hessian matrices | hessian.npy | eV/Å^2 | Nframes \* (Natoms \* 3 \* Natoms \* 3) | Second-order derivatives of energies w.r.t coordinates. | + +Note that the `hessian.npy` should contain the **full** Hessian matrices with shape of `(3Natoms * 3Natoms)` for each frame, rather than the upper or lower triangular matrices with shape of `(3Natoms * (3Natoms + 1) / 2)` for each frame. diff --git a/examples/hessian/data/H10C5N2O/set.000/box.npy b/examples/hessian/data/H10C5N2O/set.000/box.npy new file mode 100644 index 0000000000..663e5b7b76 Binary files /dev/null and b/examples/hessian/data/H10C5N2O/set.000/box.npy differ diff --git a/examples/hessian/data/H10C5N2O/set.000/coord.npy b/examples/hessian/data/H10C5N2O/set.000/coord.npy new file mode 100644 index 0000000000..9fa44dbd21 Binary files /dev/null and b/examples/hessian/data/H10C5N2O/set.000/coord.npy differ diff --git a/examples/hessian/data/H10C5N2O/set.000/energy.npy b/examples/hessian/data/H10C5N2O/set.000/energy.npy new file mode 100644 index 0000000000..94e73c8303 Binary files /dev/null and b/examples/hessian/data/H10C5N2O/set.000/energy.npy differ diff --git a/examples/hessian/data/H10C5N2O/set.000/force.npy b/examples/hessian/data/H10C5N2O/set.000/force.npy new file mode 100644 index 0000000000..b7b378fac6 Binary files /dev/null and b/examples/hessian/data/H10C5N2O/set.000/force.npy differ diff --git a/examples/hessian/data/H10C5N2O/set.000/hessian.npy b/examples/hessian/data/H10C5N2O/set.000/hessian.npy new file mode 100644 index 0000000000..8875de32bd Binary files /dev/null and b/examples/hessian/data/H10C5N2O/set.000/hessian.npy differ diff --git a/examples/hessian/data/H10C5N2O/type.raw b/examples/hessian/data/H10C5N2O/type.raw new file mode 100644 index 0000000000..034f24e9c3 --- /dev/null +++ b/examples/hessian/data/H10C5N2O/type.raw @@ -0,0 +1,18 @@ +0 +0 +0 +2 +0 +0 +2 +3 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/examples/hessian/data/H10C5N2O/type_map.raw b/examples/hessian/data/H10C5N2O/type_map.raw new file mode 100644 index 0000000000..5d0a0b4b31 --- /dev/null +++ b/examples/hessian/data/H10C5N2O/type_map.raw @@ -0,0 +1,4 @@ +C +H +N +O diff --git a/examples/hessian/data/H8C4N2O/set.000/box.npy b/examples/hessian/data/H8C4N2O/set.000/box.npy new file mode 100644 index 0000000000..9fbff5b5c4 Binary files /dev/null and b/examples/hessian/data/H8C4N2O/set.000/box.npy differ diff --git a/examples/hessian/data/H8C4N2O/set.000/coord.npy b/examples/hessian/data/H8C4N2O/set.000/coord.npy new file mode 100644 index 0000000000..1e02afc923 Binary files /dev/null and b/examples/hessian/data/H8C4N2O/set.000/coord.npy differ diff --git a/examples/hessian/data/H8C4N2O/set.000/energy.npy b/examples/hessian/data/H8C4N2O/set.000/energy.npy new file mode 100644 index 0000000000..4761a45a41 
Binary files /dev/null and b/examples/hessian/data/H8C4N2O/set.000/energy.npy differ diff --git a/examples/hessian/data/H8C4N2O/set.000/force.npy b/examples/hessian/data/H8C4N2O/set.000/force.npy new file mode 100644 index 0000000000..c69ac3552d Binary files /dev/null and b/examples/hessian/data/H8C4N2O/set.000/force.npy differ diff --git a/examples/hessian/data/H8C4N2O/set.000/hessian.npy b/examples/hessian/data/H8C4N2O/set.000/hessian.npy new file mode 100644 index 0000000000..68d6e78c55 Binary files /dev/null and b/examples/hessian/data/H8C4N2O/set.000/hessian.npy differ diff --git a/examples/hessian/data/H8C4N2O/type.raw b/examples/hessian/data/H8C4N2O/type.raw new file mode 100644 index 0000000000..a6510b1c81 --- /dev/null +++ b/examples/hessian/data/H8C4N2O/type.raw @@ -0,0 +1,15 @@ +0 +0 +0 +2 +2 +0 +3 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/examples/hessian/data/H8C4N2O/type_map.raw b/examples/hessian/data/H8C4N2O/type_map.raw new file mode 100644 index 0000000000..5d0a0b4b31 --- /dev/null +++ b/examples/hessian/data/H8C4N2O/type_map.raw @@ -0,0 +1,4 @@ +C +H +N +O diff --git a/examples/hessian/multi_task/input.json b/examples/hessian/multi_task/input.json new file mode 100644 index 0000000000..b9a347581b --- /dev/null +++ b/examples/hessian/multi_task/input.json @@ -0,0 +1,129 @@ +{ + "_comment": "that's all", + "model": { + "shared_dict": { + "type_map_all": [ + "C", + "H", + "N", + "O" + ], + "dpa1_descriptor": { + "type": "dpa1", + "sel": 120, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [ + 25, + 50, + 100 + ], + "tebd_dim": 256, + "axis_neuron": 16, + "type_one_side": true, + "attn": 128, + "attn_layer": 0, + "attn_dotr": true, + "attn_mask": false, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": true, + "temperature": 1.0 + }, + "_comment": "that's all" + }, + "model_dict": { + "H10C5N2O": { + "type_map": "type_map_all", + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + } + }, + "H8C4N2O": { + "type_map": "type_map_all", + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + } + } + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 20000, + "start_lr": 0.0002, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss_dict": { + "H10C5N2O": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "H8C4N2O": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "start_pref_h": 10, + "limit_pref_h": 1 + } + }, + "training": { + "model_prob": { + "H10C5N2O": 2.0, + "H8C4N2O": 3.0 + }, + "data_dict": { + "H10C5N2O": { + "training_data": { + "systems": [ + "../data/H10C5N2O/" + ], + "batch_size": 1, + "_comment": "that's all" + } + }, + "H8C4N2O": { + "training_data": { + "systems": [ + "../data/H8C4N2O/" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 1, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/hessian/single_task/input.json b/examples/hessian/single_task/input.json new file mode 100644 index 0000000000..e227cf342c --- /dev/null +++ 
b/examples/hessian/single_task/input.json @@ -0,0 +1,103 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "C", + "H", + "N", + "O" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 9.0, + "rcut_smth": 8.0, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh" + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 40, + "nlayers": 12, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": true, + "update_g2_has_g1g1": true, + "update_g2_has_attn": true, + "attn2_has_gate": true + }, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.0002, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "start_pref_h": 10, + "limit_pref_h": 1, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/H8C4N2O" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/H10C5N2O" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 1ddbb50db9..92ecf3a09f 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -60,11 +60,13 @@ p_examples / "water" / "dpa2" / "input_torch_compressible.json", p_examples / "property" / "train" / "input_torch.json", p_examples / "water" / "se_e3_tebd" / "input_torch.json", + p_examples / "hessian" / "single_task" / "input.json", ) input_files_multi = ( p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json", p_examples / "water_multi_task" / "pytorch_example" / "input_torch_sharefit.json", + p_examples / "hessian" / "multi_task" / "input.json", ) diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index a6cde3206c..a79ce0ab21 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -70,7 +70,7 @@ def test_1frame(self) -> None: atomic=False, fparam=result.fparam, aparam=result.aparam, - ) + )[:3] # check shape of the returns nframes = 1 natoms = len(result.atype) @@ -108,7 +108,7 @@ def test_1frame_atm(self) -> None: atomic=True, fparam=result.fparam, aparam=result.aparam, - ) + )[:5] # check shape of the returns nframes = 1 natoms = len(result.atype) @@ -174,7 +174,7 @@ def test_2frame_atm(self) -> None: atomic=True, fparam=result.fparam, aparam=result.aparam, - ) + )[:5] # check shape of the returns nframes = 2 natoms = len(result.atype) @@ -232,7 +232,7 @@ def test_zero_input(self) -> None: aparam=np.zeros([0, self.case.dim_aparam], dtype=np.float64) if self.case.dim_aparam else None, - ) + )[:3] # check shape of the returns natoms = 0 self.assertEqual(ee.shape, (nframes, 1)) diff 
--git a/source/tests/pt/hessian/data/H8C4N2O/set.000/box.npy b/source/tests/pt/hessian/data/H8C4N2O/set.000/box.npy new file mode 100644 index 0000000000..9fbff5b5c4 Binary files /dev/null and b/source/tests/pt/hessian/data/H8C4N2O/set.000/box.npy differ diff --git a/source/tests/pt/hessian/data/H8C4N2O/set.000/coord.npy b/source/tests/pt/hessian/data/H8C4N2O/set.000/coord.npy new file mode 100644 index 0000000000..1e02afc923 Binary files /dev/null and b/source/tests/pt/hessian/data/H8C4N2O/set.000/coord.npy differ diff --git a/source/tests/pt/hessian/data/H8C4N2O/set.000/energy.npy b/source/tests/pt/hessian/data/H8C4N2O/set.000/energy.npy new file mode 100644 index 0000000000..4761a45a41 Binary files /dev/null and b/source/tests/pt/hessian/data/H8C4N2O/set.000/energy.npy differ diff --git a/source/tests/pt/hessian/data/H8C4N2O/set.000/force.npy b/source/tests/pt/hessian/data/H8C4N2O/set.000/force.npy new file mode 100644 index 0000000000..c69ac3552d Binary files /dev/null and b/source/tests/pt/hessian/data/H8C4N2O/set.000/force.npy differ diff --git a/source/tests/pt/hessian/data/H8C4N2O/set.000/hessian.npy b/source/tests/pt/hessian/data/H8C4N2O/set.000/hessian.npy new file mode 100644 index 0000000000..68d6e78c55 Binary files /dev/null and b/source/tests/pt/hessian/data/H8C4N2O/set.000/hessian.npy differ diff --git a/source/tests/pt/hessian/data/H8C4N2O/type.raw b/source/tests/pt/hessian/data/H8C4N2O/type.raw new file mode 100644 index 0000000000..a6510b1c81 --- /dev/null +++ b/source/tests/pt/hessian/data/H8C4N2O/type.raw @@ -0,0 +1,15 @@ +0 +0 +0 +2 +2 +0 +3 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/hessian/data/H8C4N2O/type_map.raw b/source/tests/pt/hessian/data/H8C4N2O/type_map.raw new file mode 100644 index 0000000000..5d0a0b4b31 --- /dev/null +++ b/source/tests/pt/hessian/data/H8C4N2O/type_map.raw @@ -0,0 +1,4 @@ +C +H +N +O diff --git a/source/tests/pt/model/test_dp_hessian_model.py b/source/tests/pt/model/test_dp_hessian_model.py new file mode 100644 index 0000000000..55631f67c6 --- /dev/null +++ b/source/tests/pt/model/test_dp_hessian_model.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model import ( + EnergyModel, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithoutNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist): + def setUp(self): + TestCaseSingleFrameWithoutNlist.setUp(self) + + def test_self_consistency(self): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) + md0.enable_hessian() + md1.enable_hessian() + args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.forward(*args) + ret1 = md1.forward(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["atom_energy"]), + to_numpy_array(ret1["atom_energy"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + 
to_numpy_array(ret1["energy"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["force"]), + to_numpy_array(ret1["force"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["virial"]), + to_numpy_array(ret1["virial"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["hessian"]), + to_numpy_array(ret1["hessian"]), + atol=self.atol, + ) + ret0 = md0.forward(*args, do_atomic_virial=True) + ret1 = md1.forward(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["atom_virial"]), + to_numpy_array(ret1["atom_virial"]), + atol=self.atol, + ) + + def test_energy_consistency(self): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) + md1.enable_hessian() + args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.forward(*args) + ret1 = md1.forward(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["atom_energy"]), + to_numpy_array(ret1["atom_energy"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["force"]), + to_numpy_array(ret1["force"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["virial"]), + to_numpy_array(ret1["virial"]), + atol=self.atol, + ) + ret0 = md0.forward(*args, do_atomic_virial=True) + ret1 = md1.forward(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["atom_virial"]), + to_numpy_array(ret1["atom_virial"]), + atol=self.atol, + ) + + def test_forward_consistency(self): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) + md0.enable_hessian() + md1.enable_hessian() + md0.requires_hessian("energy") + args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.forward_common(*args) + ret1 = md1.forward(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"].squeeze()), + to_numpy_array(ret1["atom_energy"].squeeze()), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"].squeeze()), + to_numpy_array(ret1["energy"].squeeze()), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"].squeeze()), + to_numpy_array(ret1["force"].squeeze()), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"].squeeze()), + to_numpy_array(ret1["virial"].squeeze()), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r_derv_r"].squeeze()), + to_numpy_array(ret1["hessian"].squeeze()), + atol=self.atol, + ) + ret0 = md0.forward_common(*args, do_atomic_virial=True) + ret1 = md1.forward(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"].squeeze()), + to_numpy_array(ret1["atom_virial"].squeeze()), + atol=self.atol, + ) diff --git 
a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index a3cf3edbbc..58fd953656 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -87,6 +87,7 @@ def setUp(self) -> None: self.model_path_user_bias = Path(current_path) / ( model_name + "user_bias" + ".pt" ) + self.loss_params = self.config["loss"] def test_change_bias_with_data(self) -> None: run_dp( @@ -96,7 +97,10 @@ def test_change_bias_with_data(self) -> None: str(self.model_path_data_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -119,7 +123,10 @@ def test_change_bias_with_data_sys_file(self) -> None: str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() @@ -140,7 +147,10 @@ def test_change_bias_with_user_defined(self) -> None: str(self.model_path_user_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] - model_for_wrapper = get_model_for_wrapper(model_params) + model_for_wrapper = get_model_for_wrapper( + model_params, + _loss_params=self.loss_params, + ) wrapper = ModelWrapper(model_for_wrapper) wrapper.load_state_dict(state_dict["model"]) updated_bias = wrapper.model["Default"].get_out_bias() diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index d0746c1368..2519111357 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -12,6 +12,7 @@ ) from deepmd.pt.loss import ( + EnergyHessianStdLoss, EnergySpinLoss, EnergyStdLoss, ) @@ -52,6 +53,18 @@ def setUp(self) -> None: if not self.spin: self.system = str(Path(__file__).parent / "water/data/data_0") self.type_map = ["H", "O"] + if self.hess: + self.system = str(Path(__file__).parent / "hessian/data/H8C4N2O") + self.type_map = ["C", "H", "N", "O"] + energy_data_requirement.append( + DataRequirementItem( + "hessian", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ) + ) else: self.system = str(Path(__file__).parent / "NiO/data/data_0") self.type_map = ["Ni", "O"] @@ -238,6 +251,14 @@ def setUp(self) -> None: "drdq": torch.from_numpy(drdq), "atom_ener_coeff": torch.from_numpy(atom_ener_coeff), } + if self.hess: + l_hessian = np_batch["hessian"] + p_hessian = np.ones_like(l_hessian) + self.model_pred["hessian"] = torch.from_numpy(p_hessian) + self.label["hessian"] = torch.from_numpy(l_hessian) + self.label["find_hessian"] = 1.0 + self.label_absent["hessian"] = torch.from_numpy(l_hessian) + else: self.model_pred = { "energy": torch.from_numpy(p_energy), @@ -310,6 +331,7 @@ def setUp(self) -> None: self.limit_pref_v, ) self.spin = False + self.hess = False super().setUp() def test_consistency(self) -> None: @@ -399,6 +421,7 @@ def setUp(self) -> None: numb_generalized_coord=self.numb_generalized_coord, ) self.spin = False + self.hess = False super().setUp() 
def test_consistency(self) -> None: @@ -469,6 +492,7 @@ def setUp(self) -> None: enable_atom_ener_coeff=True, ) self.spin = False + self.hess = False super().setUp() def test_consistency(self) -> None: @@ -539,6 +563,7 @@ def setUp(self) -> None: relative_f=0.1, ) self.spin = False + self.hess = False super().setUp() def test_consistency(self) -> None: @@ -577,6 +602,112 @@ def fake_model(): self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) +class TestEnerHessStdLoss(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + self.start_pref_h = 10.0 + self.limit_pref_h = 1.0 + # tf + self.tf_loss = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + # pt + self.pt_loss = EnergyStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + # pt-hess + self.pt_loss_h = EnergyHessianStdLoss( + starter_learning_rate=self.start_lr, + start_pref_e=self.start_pref_e, + limit_pref_e=self.limit_pref_e, + start_pref_f=self.start_pref_f, + limit_pref_f=self.limit_pref_f, + start_pref_v=self.start_pref_v, + limit_pref_v=self.limit_pref_v, + start_pref_h=self.start_pref_h, + limit_pref_h=self.limit_pref_h, + ) + self.spin = False + self.hess = True + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + _, pt_loss_h, pt_more_loss_h = self.pt_loss_h( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_h_absent, pt_more_loss_h_absent = self.pt_loss_h( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss_h_absent = pt_loss_h_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_h_absent.numpy())) + for key in ["ener", "force", "virial"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) + ) + self.assertTrue( + np.allclose( + pt_more_loss[f"l2_{key}_loss"], pt_more_loss_h[f"l2_{key}_loss"] + ) + ) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + for key in ["ener", "force", "virial", "hessian"]: + self.assertTrue(np.isnan(pt_more_loss_h_absent[f"l2_{key}_loss"])) + + class TestEnerSpinLoss(LossCommonTest): def setUp(self) -> None: self.start_lr = 1.1 @@ -610,6 +741,7 @@ def setUp(self) -> None: self.limit_pref_fm, ) self.spin = True + self.hess = False super().setUp() def test_consistency(self) -> None: @@ -687,6 +819,7 @@ def setUp(self) -> None: limit_pref_ae=self.limit_pref_ae, ) self.spin = True + self.hess = False super().setUp() def test_consistency(self) -> None: @@ -760,6 +893,7 @@ def setUp(self) -> None: enable_atom_ener_coeff=True, ) self.spin = True + self.hess = False 
super().setUp() def test_consistency(self) -> None:
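As a usage note for the data format introduced above, here is a minimal sketch of how per-frame full Hessians could be packed into a `set.000/hessian.npy` file with the shape this patch expects, `Nframes x (3Natoms * 3Natoms)`. The matrices below are random placeholders for real second derivatives in eV/Å^2, and the symmetrization step is optional.

```python
import os

import numpy as np

natoms = 15  # e.g. the H8C4N2O example system above
nframes = 1
rng = np.random.default_rng(0)

# Placeholder per-frame full Hessians, each (3 * natoms, 3 * natoms), in eV/A^2.
h_list = [rng.random((3 * natoms, 3 * natoms)) for _ in range(nframes)]

hessian = np.stack(h_list)                              # (nframes, 3N, 3N), full matrices
hessian = 0.5 * (hessian + hessian.transpose(0, 2, 1))  # enforce symmetry (optional)

os.makedirs("set.000", exist_ok=True)
np.save("set.000/hessian.npy", hessian.reshape(nframes, -1))  # row-major flatten per frame
```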