From 25bf37a9f9a56c9561112fab9224e2e5e6b21024 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Wed, 14 Feb 2024 16:47:34 +0800 Subject: [PATCH] Implement hessian autodiff calculation (#3262) Restrictions: - the hessian model cannot be jitted - the hessian calculation is only available through the `forward_common` interface, not through `forward_common_lower` - may give nan when nall == nloc, specifically when nloc == 1. Also fixes a bug in pt `transform_output`: the output shape was wrong when the dimension of the output variable is larger than 1. --------- Co-authored-by: Han Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/dpmodel/__init__.py | 2 + deepmd/dpmodel/output_def.py | 4 + deepmd/dpmodel/utils/env_mat.py | 3 +- deepmd/pt/model/descriptor/env_mat.py | 7 +- deepmd/pt/model/model/__init__.py | 4 + deepmd/pt/model/model/make_hessian_model.py | 216 ++++++++++++++++++ deepmd/pt/model/model/transform_output.py | 7 +- .../tests/pt/model/test_make_hessian_model.py | 171 ++++++++++++++ 8 files changed, 408 insertions(+), 6 deletions(-) create mode 100644 deepmd/pt/model/model/make_hessian_model.py create mode 100644 source/tests/pt/model/test_make_hessian_model.py diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py index 5a83bb7bd4..906aac662a 100644 --- a/deepmd/dpmodel/__init__.py +++ b/deepmd/dpmodel/__init__.py @@ -14,6 +14,7 @@ OutputVariableDef, fitting_check_output, get_deriv_name, + get_hessian_name, get_reduce_name, model_check_output, ) @@ -31,4 +32,5 @@ "fitting_check_output", "get_reduce_name", "get_deriv_name", + "get_hessian_name", ] diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index fac24534eb..d816ed4e84 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -319,6 +319,10 @@ def get_deriv_name(name: str) -> Tuple[str, str]: return name + "_derv_r", name + "_derv_c" +def get_hessian_name(name: str) -> str: + return name + 
"_derv_r_derv_r" + + def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int: """Apply a operation to the category of a variable definition. diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 739b06208c..070b0e1549 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -53,7 +53,8 @@ def _make_env_mat( t0 = 1 / length t1 = diff / length**2 weight = compute_smooth_weight(length, ruct_smth, rcut) - env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1) + weight = weight * np.expand_dims(mask, -1) + env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight return env_mat_se_a, diff * np.expand_dims(mask, -1), weight diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index 63181388df..b3235de175 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -10,8 +10,10 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): """Make smooth environment matrix.""" bsz, natoms, nnei = nlist.shape coord = coord.view(bsz, -1, 3) + nall = coord.shape[1] mask = nlist >= 0 - nlist = nlist * mask + # nlist = nlist * mask ## this impl will contribute nans in Hessian calculation. 
+ nlist = torch.where(mask, nlist, nall - 1) coord_l = coord[:, :natoms].view(bsz, -1, 1, 3) index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3) coord_r = torch.gather(coord, 1, index) @@ -23,7 +25,8 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): t0 = 1 / length t1 = diff / length**2 weight = compute_smooth_weight(length, ruct_smth, rcut) - env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1) + weight = weight * mask.unsqueeze(-1) + env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight return env_mat_se_a, diff * mask.unsqueeze(-1), weight diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1948acd003..25db37a3d7 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -18,6 +18,9 @@ EnergyModel, ZBLModel, ) +from .make_hessian_model import ( + make_hessian_model, +) from .model import ( BaseModel, ) @@ -84,4 +87,5 @@ def get_model(model_params): "BaseModel", "EnergyModel", "get_model", + "make_hessian_model", ] diff --git a/deepmd/pt/model/model/make_hessian_model.py b/deepmd/pt/model/model/make_hessian_model.py new file mode 100644 index 0000000000..0ed14b1931 --- /dev/null +++ b/deepmd/pt/model/model/make_hessian_model.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import math +from typing import ( + Dict, + List, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel import ( + get_hessian_name, +) + + +def make_hessian_model(T_Model): + """Make a model that can compute Hessian. + + LIMITATION: this model is not jitable due to the restrictions of torch jit script. + + LIMITATION: only the hessian of `forward_common` is available. + + Parameters + ---------- + T_Model + The model. Should provide the `forward_common` and `fitting_output_def` methods + + Returns + ------- + The model computes hessian. 
+ + """ + + class CM(T_Model): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__( + *args, + **kwargs, + ) + self.hess_fitting_def = copy.deepcopy(super().fitting_output_def()) + + def requires_hessian( + self, + keys: Union[str, List[str]], + ): + """Set which output variable(s) requires hessian.""" + if isinstance(keys, str): + keys = [keys] + for kk in self.hess_fitting_def.keys(): + if kk in keys: + self.hess_fitting_def[kk].r_hessian = True + + def fitting_output_def(self): + """Get the fitting output def.""" + return self.hess_fitting_def + + def forward_common( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type Dict[str,torch.Tensor]. + The keys are defined by the `ModelOutputDef`. 
+ + """ + ret = super().forward_common( + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + vdef = self.fitting_output_def() + hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()] + if any(hess_yes): + hess = self._cal_hessian_all( + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + ) + ret.update(hess) + return ret + + def _cal_hessian_all( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + nf, nloc = atype.shape + coord = coord.view([nf, (nloc * 3)]) + box = box.view([nf, 9]) if box is not None else None + fparam = fparam.view([nf, -1]) if fparam is not None else None + aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None + fdef = self.fitting_output_def() + # keys of values that require hessian + hess_keys: List[str] = [] + for kk in fdef.keys(): + if fdef[kk].r_hessian: + hess_keys.append(kk) + # result dict init by empty lists + res = {get_hessian_name(kk): [] for kk in hess_keys} + # loop over variable + for kk in hess_keys: + vdef = fdef[kk] + vshape = vdef.shape + vsize = math.prod(vdef.shape) + # loop over frames + for ii in range(nf): + icoord = coord[ii] + iatype = atype[ii] + ibox = box[ii] if box is not None else None + ifparam = fparam[ii] if fparam is not None else None + iaparam = aparam[ii] if aparam is not None else None + # loop over all components + for idx in range(vsize): + hess = self._cal_hessian_one_component( + idx, icoord, iatype, ibox, ifparam, iaparam + ) + res[get_hessian_name(kk)].append(hess) + res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view( + (nf, *vshape, nloc * 3, nloc * 3) + ) + return res + + def _cal_hessian_one_component( + self, + ci, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = 
None, + ) -> torch.Tensor: + # coord, # (nloc x 3) + # atype, # nloc + # box: Optional[torch.Tensor] = None, # 9 + # fparam: Optional[torch.Tensor] = None, # nfp + # aparam: Optional[torch.Tensor] = None, # (nloc x nap) + wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam) + + hess = torch.autograd.functional.hessian( + wc, + coord, + create_graph=False, + ) + return hess + + class wrapper_class_forward_energy: + def __init__( + self, + obj: CM, + ci: int, + atype: torch.Tensor, + box: Optional[torch.Tensor], + fparam: Optional[torch.Tensor], + aparam: Optional[torch.Tensor], + ): + self.atype, self.box, self.fparam, self.aparam = atype, box, fparam, aparam + self.ci = ci + self.obj = obj + + def __call__( + self, + xx, + ): + ci = self.ci + atype, box, fparam, aparam = self.atype, self.box, self.fparam, self.aparam + res = super(CM, self.obj).forward_common( + xx.unsqueeze(0), + atype.unsqueeze(0), + box.unsqueeze(0) if box is not None else None, + fparam.unsqueeze(0) if fparam is not None else None, + aparam.unsqueeze(0) if aparam is not None else None, + do_atomic_virial=False, + ) + er = res["energy_redu"][0].view([-1])[ci] + return er + + return CM diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 27e014640d..312bb952b5 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -128,10 +128,11 @@ def take_deriv( assert aviri is not None aviri = aviri.unsqueeze(-2) split_avir.append(aviri) - # nf x nloc x v_dim x 3, nf x nloc x v_dim x 9 - ff = torch.concat(split_ff, dim=-2) + # nf x nall x v_dim x 3, nf x nall x v_dim x 9 + out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape + ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005 if do_virial: - avir = torch.concat(split_avir, dim=-2) + avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005 else: avir = None return ff, avir diff --git 
a/source/tests/pt/model/test_make_hessian_model.py b/source/tests/pt/model/test_make_hessian_model.py new file mode 100644 index 0000000000..650d35e019 --- /dev/null +++ b/source/tests/pt/model/test_make_hessian_model.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + OutputVariableCategory, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model import ( + make_hessian_model, +) +from deepmd.pt.model.model.ener import ( + DPModel, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = torch.float64 + + +def finite_hessian(f, x, delta=1e-6): + in_shape = x.shape + assert len(in_shape) == 1 + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape + in_shape) + for iidx in np.ndindex(*in_shape): + for jidx in np.ndindex(*in_shape): + i0 = np.zeros(in_shape) + i1 = np.zeros(in_shape) + i2 = np.zeros(in_shape) + i3 = np.zeros(in_shape) + i0[iidx] += delta + i2[iidx] += delta + i1[iidx] -= delta + i3[iidx] -= delta + i0[jidx] += delta + i1[jidx] += delta + i2[jidx] -= delta + i3[jidx] -= delta + y0 = f(x + i0) + y1 = f(x + i1) + y2 = f(x + i2) + y3 = f(x + i3) + res[(Ellipsis, *iidx, *jidx)] = (y0 + y3 - y1 - y2) / (4 * delta**2.0) + return res + + +class HessianTest: + def test( + self, + ): + # setup test case + places = 6 + delta = 1e-3 + natoms = self.nloc + nf = self.nf + nv = self.nv + cell0 = torch.rand([3, 3], dtype=dtype) + cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3) + cell1 = torch.rand([3, 3], dtype=dtype) + cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * torch.eye(3) + cell = torch.stack([cell0, cell1]) + coord = torch.rand([nf, natoms, 3], dtype=dtype) + coord = torch.matmul(coord, cell) + cell = cell.view([nf, 9]) + coord = coord.view([nf, natoms * 3]) + 
atype = torch.stack( + [ + torch.IntTensor([0, 0, 1]), + torch.IntTensor([1, 0, 1]), + ] + ).view([nf, natoms]) + nfp, nap = 2, 3 + fparam = torch.rand([nf, nfp], dtype=dtype) + aparam = torch.rand([nf, natoms * nap], dtype=dtype) + # forward hess and valu models + ret_dict0 = self.model_hess.forward_common( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + ret_dict1 = self.model_valu.forward_common( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + # compare hess and value models + torch.testing.assert_close(ret_dict0["energy"], ret_dict1["energy"]) + ana_hess = ret_dict0["energy_derv_r_derv_r"] + + # compute finite difference + fnt_hess = [] + for ii in range(nf): + + def np_infer( + xx, + ): + ret = self.model_valu.forward_common( + to_torch_tensor(xx).unsqueeze(0), + atype[ii].unsqueeze(0), + box=cell[ii].unsqueeze(0), + fparam=fparam[ii].unsqueeze(0), + aparam=aparam[ii].unsqueeze(0), + ) + # detach + ret = {kk: to_numpy_array(ret[kk]) for kk in ret} + return ret + + def ff(xx): + return np_infer(xx)["energy_redu"] + + xx = to_numpy_array(coord[ii]) + fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) + + # compare finite difference with autodiff + fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3]) + np.testing.assert_almost_equal( + fnt_hess, to_numpy_array(ana_hess), decimal=places + ) + + +class TestDPModel(unittest.TestCase, HessianTest): + def setUp(self): + torch.manual_seed(2) + self.nf = 2 + self.nloc = 3 + self.rcut = 4.0 + self.rcut_smth = 3.0 + self.sel = [10, 10] + self.nt = 2 + self.nv = 2 + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ).to(env.DEVICE) + ft0 = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + self.nv, + distinguish_types=ds.distinguish_types(), + do_hessian=True, + neuron=[4, 4, 4], + ).to(env.DEVICE) + type_map = ["foo", "bar"] + self.model_hess = make_hessian_model(DPModel)(ds, ft0, type_map=type_map).to( + 
env.DEVICE + ) + self.model_valu = DPModel.deserialize(self.model_hess.serialize()) + self.model_hess.requires_hessian("energy") + + def test_output_def(self): + self.assertTrue(self.model_hess.fitting_output_def()["energy"].r_hessian) + self.assertFalse(self.model_valu.fitting_output_def()["energy"].r_hessian) + self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian) + self.assertEqual( + self.model_hess.model_output_def()["energy_derv_r_derv_r"].category, + OutputVariableCategory.DERV_R_DERV_R, + )