Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compute output stat for atomic model #3642

Merged
merged 13 commits into from
Apr 7, 2024
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in subclass
DPAtomicModel
.
Assignment overwrites attribute type_map, which was previously defined in subclass
DPAtomicModel
.
Assignment overwrites attribute type_map, which was previously defined in subclass
LinearEnergyAtomicModel
.
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def get_type_map(self) -> List[str]:
    """Return the type map held by this atomic model (``self.type_map``)."""
    type_map = self.type_map
    return type_map

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
Expand Down
6 changes: 1 addition & 5 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand All @@ -67,10 +67,6 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down
3 changes: 0 additions & 3 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def atomic_output_def(self) -> FittingOutputDef:
"""
return self.fitting_output_def()

def get_output_keys(self) -> List[str]:
return list(self.atomic_output_def().keys())

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,20 @@
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
**kwargs,
):
super().__init__()
super().__init__(type_map, **kwargs)
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.
self.ntypes = len(type_map)
self.rcond = rcond
self.atom_ener = atom_ener

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ def __init__(
if not self.r_differentiable:
raise ValueError("only r_differentiable variable can calculate hessian")

@property
def size(self):
    """Read-only alias that exposes ``self.output_size`` under the name ``size``."""
    output_size = self.output_size
    return output_size


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.
Expand Down
186 changes: 176 additions & 10 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
env,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
Expand All @@ -35,19 +36,64 @@
)

log = logging.getLogger(__name__)
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_):
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
torch.nn.Module.__init__(self)
BaseAtomicModel_.__init__(self)
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = None
self.atom_ener = None

def init_out_stat(self):
    """Initialize the output bias and standard-deviation buffers.

    Records the output keys in ``self.bias_keys`` (taken from
    ``fitting_output_def()``), the number of such keys in ``self.n_out`` and
    the largest per-key ``size`` in ``self.max_out_size``.  Two buffers of
    shape ``[n_out, ntypes, max_out_size]`` are then registered on the module:
    ``out_bias`` filled with zeros and ``out_std`` filled with ones.
    """
    self.bias_keys: List[str] = list(self.fitting_output_def().keys())
    out_def = self.atomic_output_def()
    # One common trailing width, wide enough for the largest output variable.
    self.max_out_size = max([out_def[key].size for key in self.bias_keys])
    self.n_out = len(self.bias_keys)
    stat_shape = [self.n_out, self.get_ntypes(), self.max_out_size]
    self.register_buffer(
        "out_bias", torch.zeros(stat_shape, dtype=dtype, device=device)
    )
    self.register_buffer(
        "out_std", torch.ones(stat_shape, dtype=dtype, device=device)
    )

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
self.out_std = value

Check warning on line 81 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L80-L81

Added lines #L80 - L81 were not covered by tests
else:
raise KeyError(key)

Check warning on line 83 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L83

Added line #L83 was not covered by tests

def __getitem__(self, key):
if key in ["out_bias"]:
return self.out_bias
elif key in ["out_std"]:
return self.out_std

Check warning on line 89 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L88-L89

Added lines #L88 - L89 were not covered by tests
else:
raise KeyError(key)

Check warning on line 92 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L91-L92

Added lines #L91 - L92 were not covered by tests
@torch.jit.export
def get_type_map(self) -> List[str]:
    """Return the type map held by this atomic model (``self.type_map``)."""
    type_map = self.type_map
    return type_map

def reinit_atom_exclude(
self,
Expand Down Expand Up @@ -165,6 +211,7 @@
fparam=fparam,
aparam=aparam,
)
ret_dict = self.apply_out_bias(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
Expand Down Expand Up @@ -210,9 +257,60 @@
"""
raise NotImplementedError

def compute_or_load_out_stat(
    self,
    merged: Union[Callable[[], List[dict]], List[dict]],
    stat_file_path: Optional[DPPath] = None,
):
    """
    Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

    This is a thin wrapper: it forwards both arguments to ``change_out_bias``
    with ``bias_adjust_mode="set-by-statistic"``.

    Parameters
    ----------
    merged : Union[Callable[[], List[dict]], List[dict]]
        - List[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
            originating from the `i`-th data system.
        - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    stat_file_path : Optional[DPPath]
        The path to the stat file.

    """
    adjust_mode = "set-by-statistic"
    self.change_out_bias(
        merged,
        stat_file_path=stat_file_path,
        bias_adjust_mode=adjust_mode,
    )

def apply_out_bias(
    self,
    ret: Dict[str, torch.Tensor],
    atype: torch.Tensor,
):
    """Add the stored per-type bias to every biased atomic output.

    Subclasses may override this method to change how the bias is applied
    to the atomic output of the model.

    Parameters
    ----------
    ret
        The dict returned by the forward_atomic method. Its entries listed in
        ``self.bias_keys`` are replaced by their biased values.
    atype
        The atom types. nf x nloc

    Returns
    -------
    The same ``ret`` dict, after the bias has been added.

    """
    # Only the bias half of the (bias, std) pair is used here.
    bias_by_key = self._fetch_out_stat(self.bias_keys)[0]
    for key in self.bias_keys:
        # ret[key]: nf x nloc x odims; bias_by_key[key]: ntypes x odims,
        # indexed by atype to pick the bias of each atom's type.
        per_atom_bias = bias_by_key[key][atype]
        ret[key] = ret[key] + per_atom_bias
    return ret

def change_out_bias(
self,
sample_merged,
stat_file_path: Optional[DPPath] = None,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.
Expand All @@ -233,20 +331,28 @@
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
delta_bias, out_std = compute_output_stats(
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
model_forward=self._get_forward_wrapper_func(),
)["energy"]
self.set_out_bias(delta_bias, add=True)
rcond=self.rcond,
atom_ener=self.atom_ener,
)
# self.set_out_bias(delta_bias, add=True)
self._store_out_stat(delta_bias, out_std, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
bias_out, std_out = compute_output_stats(
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
)["energy"]
self.set_out_bias(bias_atom)
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)
# self.set_out_bias(bias_out)
self._store_out_stat(bias_out, std_out)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

Expand Down Expand Up @@ -279,3 +385,63 @@
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward

def _varsize(
    self,
    shape: List[int],
) -> int:
    """Return the product of the entries of ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total *= dim
    return total

def _get_bias_index(
    self,
    kk: str,
) -> int:
    """Return the position of ``kk`` in ``self.bias_keys``.

    The key must occur exactly once; otherwise the assertion fails.
    """
    hits: List[int] = []
    for pos in range(len(self.bias_keys)):
        if self.bias_keys[pos] == kk:
            hits.append(pos)
    assert len(hits) == 1
    return hits[0]

def _store_out_stat(
    self,
    out_bias: Dict[str, torch.Tensor],
    out_std: Dict[str, torch.Tensor],
    add: bool = False,
):
    """Write per-key bias and std values into the ``out_bias``/``out_std`` buffers.

    Parameters
    ----------
    out_bias
        Bias per output key; each value is viewed as ``(ntypes, size)``.
    out_std
        Standard deviation per output key; must contain every key of
        ``out_bias``. Each value is viewed as ``(ntypes, size)``.
    add
        If True the bias is accumulated onto the stored bias, otherwise it
        overwrites it. The std is overwritten in both cases.

    """
    ntypes = self.get_ntypes()
    # Work on copies and write back in one step at the end.
    new_bias = self.out_bias.clone()
    new_std = self.out_std.clone()
    for key in out_bias.keys():
        assert key in out_std
        row = self._get_bias_index(key)
        width = self._varsize(self.atomic_output_def()[key].shape)
        if add:
            new_bias[row, :, :width] += out_bias[key].view(ntypes, width)
        else:
            new_bias[row, :, :width] = out_bias[key].view(ntypes, width)
        new_std[row, :, :width] = out_std[key].view(ntypes, width)
    self.out_bias.copy_(new_bias)
    self.out_std.copy_(new_std)

def _fetch_out_stat(
    self,
    keys: List[str],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Read the stored bias and std for the requested output keys.

    For each key the matching row of ``out_bias``/``out_std`` is sliced to the
    variable's flattened size and viewed as ``[ntypes] + shape``, where
    ``shape`` comes from ``atomic_output_def()[key].shape``.

    Returns
    -------
    A ``(bias, std)`` pair of dicts keyed by the entries of ``keys``.

    """
    ntypes = self.get_ntypes()
    bias_by_key = {}
    std_by_key = {}
    for key in keys:
        row = self._get_bias_index(key)
        var_shape = self.atomic_output_def()[key].shape
        width = self._varsize(var_shape)
        full_shape = [ntypes] + list(var_shape)  # noqa: RUF005
        bias_by_key[key] = self.out_bias[row, :, :width].view(full_shape)
        std_by_key[key] = self.out_std[row, :, :width].view(full_shape)
    return bias_by_key, std_by_key
16 changes: 5 additions & 11 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


@BaseAtomicModel.register("standard")
class DPAtomicModel(torch.nn.Module, BaseAtomicModel):
class DPAtomicModel(BaseAtomicModel):
"""Model give atomic prediction of some physical property.

Parameters
Expand All @@ -55,17 +55,17 @@
type_map: List[str],
**kwargs,
):
torch.nn.Module.__init__(self)
super().__init__(type_map, **kwargs)
ntypes = len(type_map)
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.
self.ntypes = ntypes
self.descriptor = descriptor
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
# order matters ntypes and type_map should be initialized first.
BaseAtomicModel.__init__(self, **kwargs)
super().init_out_stat()

@torch.jit.export
def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return (
Expand All @@ -79,11 +79,6 @@
"""Get the cut-off radius."""
return self.rcut

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.sel
Expand Down Expand Up @@ -220,8 +215,7 @@
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Expand Down
Loading