Skip to content

Commit

Permalink
Implement hessian autodiff calculation (#3262)
Browse files Browse the repository at this point in the history
restrictions:
- cannot jit
- only the `forward_common` interface has its hessian calculation. not
for `forward_common_lower`.
- may give nan when nall == nloc. specifically when nloc==1

also fix bug in pt: transform_output. The output shape will be wrong
when the dimension of output variable is larger than 1.

---------

Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 14, 2024
1 parent 977b430 commit 25bf37a
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 6 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
OutputVariableDef,
fitting_check_output,
get_deriv_name,
get_hessian_name,
get_reduce_name,
model_check_output,
)
Expand All @@ -31,4 +32,5 @@
"fitting_check_output",
"get_reduce_name",
"get_deriv_name",
"get_hessian_name",
]
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def get_deriv_name(name: str) -> Tuple[str, str]:
return name + "_derv_r", name + "_derv_c"


def get_hessian_name(name: str) -> str:
return name + "_derv_r_derv_r"


def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
"""Apply a operation to the category of a variable definition.
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def _make_env_mat(
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
weight = weight * np.expand_dims(mask, -1)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight
return env_mat_se_a, diff * np.expand_dims(mask, -1), weight


Expand Down
7 changes: 5 additions & 2 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float):
"""Make smooth environment matrix."""
bsz, natoms, nnei = nlist.shape
coord = coord.view(bsz, -1, 3)
nall = coord.shape[1]
mask = nlist >= 0
nlist = nlist * mask
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
Expand All @@ -23,7 +25,8 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float):
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1)
weight = weight * mask.unsqueeze(-1)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
EnergyModel,
ZBLModel,
)
from .make_hessian_model import (
make_hessian_model,
)
from .model import (
BaseModel,
)
Expand Down Expand Up @@ -84,4 +87,5 @@ def get_model(model_params):
"BaseModel",
"EnergyModel",
"get_model",
"make_hessian_model",
]
216 changes: 216 additions & 0 deletions deepmd/pt/model/model/make_hessian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import math
from typing import (
Dict,
List,
Optional,
Union,
)

import torch

from deepmd.dpmodel import (
get_hessian_name,
)


def make_hessian_model(T_Model):
"""Make a model that can compute Hessian.
LIMITATION: this model is not jitable due to the restrictions of torch jit script.
LIMITATION: only the hessian of `forward_common` is available.
Parameters
----------
T_Model
The model. Should provide the `forward_common` and `fitting_output_def` methods
Returns
-------
The model computes hessian.
"""

class CM(T_Model):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(
*args,
**kwargs,
)
self.hess_fitting_def = copy.deepcopy(super().fitting_output_def())

def requires_hessian(
self,
keys: Union[str, List[str]],
):
"""Set which output variable(s) requires hessian."""
if isinstance(keys, str):
keys = [keys]
for kk in self.hess_fitting_def.keys():
if kk in keys:
self.hess_fitting_def[kk].r_hessian = True

def fitting_output_def(self):
"""Get the fitting output def."""
return self.hess_fitting_def

def forward_common(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return model prediction.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type Dict[str,torch.Tensor].
The keys are defined by the `ModelOutputDef`.
"""
ret = super().forward_common(
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
vdef = self.fitting_output_def()
hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()]
if any(hess_yes):
hess = self._cal_hessian_all(
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
)
ret.update(hess)
return ret

def _cal_hessian_all(
self,
coord: torch.Tensor,
atype: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
nf, nloc = atype.shape
coord = coord.view([nf, (nloc * 3)])
box = box.view([nf, 9]) if box is not None else None
fparam = fparam.view([nf, -1]) if fparam is not None else None
aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None
fdef = self.fitting_output_def()
# keys of values that require hessian
hess_keys: List[str] = []
for kk in fdef.keys():
if fdef[kk].r_hessian:
hess_keys.append(kk)
# result dict init by empty lists
res = {get_hessian_name(kk): [] for kk in hess_keys}
# loop over variable
for kk in hess_keys:
vdef = fdef[kk]
vshape = vdef.shape
vsize = math.prod(vdef.shape)
# loop over frames
for ii in range(nf):
icoord = coord[ii]
iatype = atype[ii]
ibox = box[ii] if box is not None else None
ifparam = fparam[ii] if fparam is not None else None
iaparam = aparam[ii] if aparam is not None else None
# loop over all components
for idx in range(vsize):
hess = self._cal_hessian_one_component(
idx, icoord, iatype, ibox, ifparam, iaparam
)
res[get_hessian_name(kk)].append(hess)
res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view(
(nf, *vshape, nloc * 3, nloc * 3)
)
return res

def _cal_hessian_one_component(
self,
ci,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# coord, # (nloc x 3)
# atype, # nloc
# box: Optional[torch.Tensor] = None, # 9
# fparam: Optional[torch.Tensor] = None, # nfp
# aparam: Optional[torch.Tensor] = None, # (nloc x nap)
wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam)

hess = torch.autograd.functional.hessian(
wc,
coord,
create_graph=False,
)
return hess

class wrapper_class_forward_energy:
def __init__(
self,
obj: CM,
ci: int,
atype: torch.Tensor,
box: Optional[torch.Tensor],
fparam: Optional[torch.Tensor],
aparam: Optional[torch.Tensor],
):
self.atype, self.box, self.fparam, self.aparam = atype, box, fparam, aparam
self.ci = ci
self.obj = obj

def __call__(
self,
xx,
):
ci = self.ci
atype, box, fparam, aparam = self.atype, self.box, self.fparam, self.aparam
res = super(CM, self.obj).forward_common(
xx.unsqueeze(0),
atype.unsqueeze(0),
box.unsqueeze(0) if box is not None else None,
fparam.unsqueeze(0) if fparam is not None else None,
aparam.unsqueeze(0) if aparam is not None else None,
do_atomic_virial=False,
)
er = res["energy_redu"][0].view([-1])[ci]
return er

return CM
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def take_deriv(
assert aviri is not None
aviri = aviri.unsqueeze(-2)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
# nf x nall x v_dim x 3, nf x nall x v_dim x 9
out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape
ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005
if do_virial:
avir = torch.concat(split_avir, dim=-2)
avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005
else:
avir = None
return ff, avir
Expand Down
Loading

0 comments on commit 25bf37a

Please sign in to comment.