Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hessian autodiff calculation #3262

Merged
merged 10 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
OutputVariableDef,
fitting_check_output,
get_deriv_name,
get_hessian_name,
get_reduce_name,
model_check_output,
)
Expand All @@ -31,4 +32,5 @@
"fitting_check_output",
"get_reduce_name",
"get_deriv_name",
"get_hessian_name",
]
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@
return name + "_derv_r", name + "_derv_c"


def get_hessian_name(name: str) -> str:
return name + "_derv_r_derv_r"

Check warning on line 323 in deepmd/dpmodel/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/output_def.py#L323

Added line #L323 was not covered by tests


def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
"""Apply a operation to the category of a variable definition.

Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
weight = weight * np.expand_dims(mask, -1)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight

Check warning on line 57 in deepmd/dpmodel/utils/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/env_mat.py#L56-L57

Added lines #L56 - L57 were not covered by tests
return env_mat_se_a, diff * np.expand_dims(mask, -1), weight


Expand Down
7 changes: 5 additions & 2 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
"""Make smooth environment matrix."""
bsz, natoms, nnei = nlist.shape
coord = coord.view(bsz, -1, 3)
nall = coord.shape[1]

Check warning on line 13 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L13

Added line #L13 was not covered by tests
mask = nlist >= 0
nlist = nlist * mask
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)

Check warning on line 16 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L16

Added line #L16 was not covered by tests
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
Expand All @@ -23,7 +25,8 @@
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1)
weight = weight * mask.unsqueeze(-1)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight

Check warning on line 29 in deepmd/pt/model/descriptor/env_mat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/env_mat.py#L28-L29

Added lines #L28 - L29 were not covered by tests
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
EnergyModel,
ZBLModel,
)
from .make_hessian_model import (

Check warning on line 21 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L21

Added line #L21 was not covered by tests
make_hessian_model,
)
from .model import (
BaseModel,
)
Expand Down
21 changes: 19 additions & 2 deletions deepmd/pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def forward(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand All @@ -63,13 +66,18 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
Expand Down Expand Up @@ -109,7 +117,12 @@ def forward(
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord, atype, box, do_atomic_virial=do_atomic_virial
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
Expand All @@ -135,13 +148,17 @@ def forward_lower(
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
Expand Down
217 changes: 217 additions & 0 deletions deepmd/pt/model/model/make_hessian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (

Check warning on line 3 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L2-L3

Added lines #L2 - L3 were not covered by tests
Dict,
List,
Optional,
Union,
)

import torch

Check warning on line 10 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L10

Added line #L10 was not covered by tests

from deepmd.dpmodel import (

Check warning on line 12 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L12

Added line #L12 was not covered by tests
get_hessian_name,
)


def make_hessian_model(T_Model):

Check warning on line 17 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L17

Added line #L17 was not covered by tests
"""Make a model that can compute Hessian.

LIMITATION: this model is not jitable due to the restrictions of torch jit script.

LIMITATION: only the hessian of `forward_common` is available.

Parameters
----------
T_Model
The model. Should provide the `forward_common` and `fitting_output_def` methods

Returns
-------
The model computes hessian.

"""

class CM(T_Model):
def __init__(

Check warning on line 36 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L35-L36

Added lines #L35 - L36 were not covered by tests
self,
*args,
**kwargs,
):
super().__init__(

Check warning on line 41 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L41

Added line #L41 was not covered by tests
*args,
**kwargs,
)
self.hess_fitting_def = copy.deepcopy(super().fitting_output_def())

Check warning on line 45 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L45

Added line #L45 was not covered by tests

def requires_hessian(

Check warning on line 47 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L47

Added line #L47 was not covered by tests
self,
keys: Union[str, List[str]],
):
"""Set which output variable(s) requires hessian."""
if isinstance(keys, str):
keys = [keys]
for kk in self.hess_fitting_def.keys():
if kk in keys:
self.hess_fitting_def[kk].r_hessian = True

Check warning on line 56 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L52-L56

Added lines #L52 - L56 were not covered by tests

def fitting_output_def(self):

Check warning on line 58 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L58

Added line #L58 was not covered by tests
"""Get the fitting output def."""
return self.hess_fitting_def

Check warning on line 60 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L60

Added line #L60 was not covered by tests

def forward_common(

Check warning on line 62 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L62

Added line #L62 was not covered by tests
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return model prediction.

Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.

Returns
-------
ret_dict
The result dict of type Dict[str,torch.Tensor].
The keys are defined by the `ModelOutputDef`.

"""
ret = super().forward_common(

Check warning on line 96 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L96

Added line #L96 was not covered by tests
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
vdef = self.fitting_output_def()
hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()]
if any(hess_yes):
hess = self._cal_hessian_all(

Check warning on line 107 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L104-L107

Added lines #L104 - L107 were not covered by tests
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
)
ret.update(hess)
return ret

Check warning on line 115 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L114-L115

Added lines #L114 - L115 were not covered by tests

def _cal_hessian_all(

Check warning on line 117 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L117

Added line #L117 was not covered by tests
self,
coord: torch.Tensor,
atype: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
nf, nloc = atype.shape
coord = coord.view([nf, (nloc * 3)])
box = box.view([nf, 9]) if box is not None else None
fparam = fparam.view([nf, -1]) if fparam is not None else None
aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None
fdef = self.fitting_output_def()

Check warning on line 130 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L125-L130

Added lines #L125 - L130 were not covered by tests
# keys of values that require hessian
hess_keys: List[str] = []
for kk in fdef.keys():
if fdef[kk].r_hessian:
hess_keys.append(kk)

Check warning on line 135 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L132-L135

Added lines #L132 - L135 were not covered by tests
# result dict init by empty lists
res = {get_hessian_name(kk): [] for kk in hess_keys}

Check warning on line 137 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L137

Added line #L137 was not covered by tests
# loop over variable
for kk in hess_keys:
vdef = fdef[kk]
vshape = vdef.shape
vsize = 1
for ii in vshape:
vsize *= ii

Check warning on line 144 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L139-L144

Added lines #L139 - L144 were not covered by tests
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
# loop over frames
for ii in range(nf):
icoord = coord[ii]
iatype = atype[ii]
ibox = box[ii] if box is not None else None
ifparam = fparam[ii] if fparam is not None else None
iaparam = aparam[ii] if aparam is not None else None

Check warning on line 151 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L146-L151

Added lines #L146 - L151 were not covered by tests
# loop over all components
for idx in range(vsize):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
hess = self._cal_hessian_one_component(

Check warning on line 154 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L153-L154

Added lines #L153 - L154 were not covered by tests
idx, icoord, iatype, ibox, ifparam, iaparam
)
res[get_hessian_name(kk)].append(hess)
res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view(

Check warning on line 158 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L157-L158

Added lines #L157 - L158 were not covered by tests
(nf, *vshape, nloc * 3, nloc * 3)
)
return res

Check warning on line 161 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L161

Added line #L161 was not covered by tests

def _cal_hessian_one_component(

Check warning on line 163 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L163

Added line #L163 was not covered by tests
self,
ci,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# coord, # (nloc x 3)
# atype, # nloc
# box: Optional[torch.Tensor] = None, # 9
# fparam: Optional[torch.Tensor] = None, # nfp
# aparam: Optional[torch.Tensor] = None, # (nloc x nap)
wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam)

Check warning on line 177 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L177

Added line #L177 was not covered by tests

hess = torch.autograd.functional.hessian(

Check warning on line 179 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L179

Added line #L179 was not covered by tests
wc,
coord,
create_graph=False,
)
return hess

Check warning on line 184 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L184

Added line #L184 was not covered by tests

class wrapper_class_forward_energy:
def __init__(

Check warning on line 187 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L186-L187

Added lines #L186 - L187 were not covered by tests
self,
obj: CM,
ci: int,
atype: torch.Tensor,
box: Optional[torch.Tensor],
fparam: Optional[torch.Tensor],
aparam: Optional[torch.Tensor],
):
self.atype, self.box, self.fparam, self.aparam = atype, box, fparam, aparam
self.ci = ci
self.obj = obj

Check warning on line 198 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L196-L198

Added lines #L196 - L198 were not covered by tests

def __call__(

Check warning on line 200 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L200

Added line #L200 was not covered by tests
self,
xx,
):
ci = self.ci
atype, box, fparam, aparam = self.atype, self.box, self.fparam, self.aparam
res = super(CM, self.obj).forward_common(

Check warning on line 206 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L204-L206

Added lines #L204 - L206 were not covered by tests
xx.unsqueeze(0),
atype.unsqueeze(0),
box.unsqueeze(0) if box is not None else None,
fparam.unsqueeze(0) if fparam is not None else None,
aparam.unsqueeze(0) if aparam is not None else None,
do_atomic_virial=False,
)
er = res["energy_redu"][0].view([-1])[ci]
return er

Check warning on line 215 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L214-L215

Added lines #L214 - L215 were not covered by tests

return CM

Check warning on line 217 in deepmd/pt/model/model/make_hessian_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_hessian_model.py#L217

Added line #L217 was not covered by tests
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def forward_common(
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.

Expand Down Expand Up @@ -155,6 +159,10 @@ def forward_common_lower(
neighbor list. nf x nloc x nsel.
mapping
mapps the extended indices to local indices. nf x nall.
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.

Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@
assert aviri is not None
aviri = aviri.unsqueeze(-2)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
# nf x nall x v_dim x 3, nf x nall x v_dim x 9
out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape
ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005

Check warning on line 133 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L132-L133

Added lines #L132 - L133 were not covered by tests
if do_virial:
avir = torch.concat(split_avir, dim=-2)
avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005

Check warning on line 135 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L135

Added line #L135 was not covered by tests
else:
avir = None
return ff, avir
Expand Down
Loading