Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add definition for the output of fitting and model #3128

Merged
merged 9 commits into from
Jan 12, 2024
14 changes: 14 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
save_dp_model,
traverse_model_dict,
)
from .output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
VariableDef,
fitting_check_output,
model_check_output,
)
from .se_e2_a import (
DescrptSeA,
)
Expand All @@ -31,4 +39,10 @@
"traverse_model_dict",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
"VariableDef",
"model_check_output",
"fitting_check_output",
]
278 changes: 278 additions & 0 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Tuple,
Union,
)


def check_var(var, var_def):
if var_def.atomic:
# var.shape == [nf, nloc, *var_def.shape]
if len(var.shape) != len(var_def.shape) + 2:
raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}")
if list(var.shape[2:]) != var_def.shape:
raise ValueError(f"{var.shape[2:]} not matching def {var_def.shape}")
else:
# var.shape == [nf, *var_def.shape]
if len(var.shape) != len(var_def.shape) + 1:
raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}")
if list(var.shape[1:]) != var_def.shape:
raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}")
njzjz marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved


def model_check_output(cls):
"""Check if the output of the Model is consistent with the definition.

Two methods are assumed to be provided by the Model:
1. Model.output_def that gives the output definition.
2. Model.forward that defines the forward path of the model.

"""

class wrapper(cls):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)

def forward(
self,
*args,
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
for kk in self.md.keys_outp():
dd = self.md[kk]
check_var(ret[kk], dd)
if dd.reduciable:
rk = get_reduce_name(kk)
check_var(ret[rk], self.md[rk])
if dd.differentiable:
dnr, dnc = get_deriv_name(kk)
check_var(ret[dnr], self.md[dnr])
check_var(ret[dnc], self.md[dnc])
return ret

return wrapper


def fitting_check_output(cls):
"""Check if the output of the Fitting is consistent with the definition.

Two methods are assumed to be provided by the Fitting:
1. Fitting.output_def that gives the output definition.
2. Fitting.forward defines the forward path of the fitting.

"""

class wrapper(cls):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)

def forward(
self,
*args,
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
for kk in self.md.keys():
dd = self.md[kk]
check_var(ret[kk], dd)
return ret

return wrapper
njzjz marked this conversation as resolved.
Show resolved Hide resolved
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved


class VariableDef:
"""Defines the shape and other properties of a variable.

Parameters
----------
name
Name of the output variable. Notice that the xxxx_redu,
xxxx_derv_c, xxxx_derv_r are reserved names that should
not be used to define variables.
shape
The shape of the variable. e.g. energy should be [1],
njzjz marked this conversation as resolved.
Show resolved Hide resolved
dipole should be [3], polarizabilty should be [3,3].
atomic
If the variable is defined for each atom.

"""

def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
atomic: bool = True,
):
self.name = name
self.shape = list(shape)
self.atomic = atomic


class OutputVariableDef(VariableDef):
"""Defines the shape and other properties of the one output variable.

It is assume that the fitting network output variables for each
local atom. This class defines one output variable, including its
name, shape, reducibility and differentiability.

Parameters
----------
name
Name of the output variable. Notice that the xxxx_redu,
xxxx_derv_c, xxxx_derv_r are reserved names that should
not be used to define variables.
shape
The shape of the variable. e.g. energy should be [1],
dipole should be [3], polarizabilty should be [3,3].
reduciable
If the variable is reduced.
differentiable
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
are differentiable.

"""

def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
reduciable: bool = False,
differentiable: bool = False,
):
# fitting output must be atomic
super().__init__(name, shape, atomic=True)
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
raise ValueError("only reduciable variable are differentiable")


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.

It is assume that the fitting network output variables for each
local atom. This class defines all the outputs.

Parameters
----------
var_defs
List of output variable definitions.

"""

def __init__(
self,
var_defs: List[OutputVariableDef] = [],
):
self.var_defs = {vv.name: vv for vv in var_defs}

def __getitem__(
self,
key,
) -> OutputVariableDef:
return self.var_defs[key]

def get_data(self) -> Dict[str, OutputVariableDef]:
return self.var_defs

def keys(self):
return self.var_defs.keys()


class ModelOutputDef:
"""Defines the shapes and other properties of the model outputs.

The model reduce and differentiate fitting outputs if applicable.
If a variable is named by foo, then the reduced variable is called
foo_redu, the derivative w.r.t. coordinates is called foo_derv_r
and the derivative w.r.t. cell is called foo_derv_c.

Parameters
----------
fit_defs
Definition for the fitting net output

"""

def __init__(
self,
fit_defs: FittingOutputDef,
):
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp)
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
self.var_defs = {}
for ii in [
self.def_outp.get_data(),
self.def_redu,
self.def_derv_c,
self.def_derv_r,
]:
self.var_defs.update(ii)

def __getitem__(self, key) -> VariableDef:
return self.var_defs[key]

def get_data(self, key) -> Dict[str, VariableDef]:
return self.var_defs

Check warning on line 231 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L231

Added line #L231 was not covered by tests

def keys(self):
return self.var_defs.keys()

wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
def keys_outp(self):
return self.def_outp.keys()

def keys_redu(self):
return self.def_redu.keys()

Check warning on line 240 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L240

Added line #L240 was not covered by tests

def keys_derv_r(self):
return self.def_derv_r.keys()

Check warning on line 243 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L243

Added line #L243 was not covered by tests

def keys_derv_c(self):
return self.def_derv_c.keys()

Check warning on line 246 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L246

Added line #L246 was not covered by tests


def get_reduce_name(name):
return name + "_redu"


def get_deriv_name(name):
return name + "_derv_r", name + "_derv_c"
njzjz marked this conversation as resolved.
Show resolved Hide resolved


def do_reduce(
def_outp,
):
def_redu = {}
for kk, vv in def_outp.get_data().items():
if vv.reduciable:
rk = get_reduce_name(kk)
def_redu[rk] = VariableDef(rk, vv.shape, atomic=False)
return def_redu


def do_derivative(
def_outp,
):
def_derv_r = {}
def_derv_c = {}
for kk, vv in def_outp.get_data().items():
if vv.differentiable:
rkr, rkc = get_deriv_name(kk)
def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True)
def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False)
return def_derv_r, def_derv_c
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading