refact: compute_output_stats and change_out_bias (#3639)
-    clean up the interface of `change_out_bias`
-    refactor `compute_output_stats` so that the code is more readable.

---------

Co-authored-by: Han Wang <[email protected]>
wanghan-iapcm and Han Wang authored Apr 3, 2024
1 parent 4c546d0 commit 073f559
Showing 8 changed files with 203 additions and 169 deletions.
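The net effect of the interface cleanup is easiest to see at the call site: `change_out_bias` no longer takes type-map arguments. A before/after sketch (with `model` and `sample_func` as hypothetical stand-ins):

    # Before this commit: callers passed both type maps, and the atomic
    # model remapped and logged the bias itself.
    model.change_out_bias(
        sample_func,
        origin_type_map=new_type_map,
        full_type_map=old_type_map,
        bias_adjust_mode="change-by-statistic",
    )

    # After this commit: the model only consumes sampled data; the type-map
    # bookkeeping and logging move to _model_change_out_bias in training.py.
    model.change_out_bias(
        sample_func,
        bias_adjust_mode="change-by-statistic",
    )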
3 changes: 3 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -51,6 +51,9 @@ def atomic_output_def(self) -> FittingOutputDef:
         """
         return self.fitting_output_def()
 
+    def get_output_keys(self) -> List[str]:
+        return list(self.atomic_output_def().keys())
+
     @abstractmethod
     def get_rcut(self) -> float:
         """Get the cut-off radius."""
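The new `get_output_keys` helper exposes the fitting net's output keys so that callers no longer hard-code `"energy"`. A hypothetical usage, assuming an energy model whose fitting output definition holds a single `energy` entry:

    # get_output_keys lists the keys of the fitting output definition.
    keys = model.get_output_keys()  # e.g. ["energy"]
    # change_out_bias (below) now passes these keys to compute_output_stats
    # instead of the literal ["energy"].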
123 changes: 50 additions & 73 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -8,9 +8,9 @@
     List,
     Optional,
     Tuple,
+    Union,
 )
 
-import numpy as np
 import torch
 
 from deepmd.dpmodel.atomic_model import (
@@ -30,9 +30,6 @@
 from deepmd.pt.utils.stat import (
     compute_output_stats,
 )
-from deepmd.pt.utils.utils import (
-    to_numpy_array,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -190,115 +187,95 @@ def serialize(self) -> dict:
             "pair_exclude_types": self.pair_exclude_types,
         }
 
-    def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
-        """Get a forward wrapper of the atomic model for output bias calculation."""
-
-        def model_forward(coord, atype, box, fparam=None, aparam=None):
-            with torch.no_grad():  # it's essential for pure torch forward function to use auto_batchsize
-                (
-                    extended_coord,
-                    extended_atype,
-                    mapping,
-                    nlist,
-                ) = extend_input_and_build_neighbor_list(
-                    coord,
-                    atype,
-                    self.get_rcut(),
-                    self.get_sel(),
-                    mixed_types=self.mixed_types(),
-                    box=box,
-                )
-                atomic_ret = self.forward_common_atomic(
-                    extended_coord,
-                    extended_atype,
-                    nlist,
-                    mapping=mapping,
-                    fparam=fparam,
-                    aparam=aparam,
-                )
-                return {kk: vv.detach() for kk, vv in atomic_ret.items()}
-
-        return model_forward
-
     def compute_or_load_stat(
         self,
-        sampled_func,
+        merged: Union[Callable[[], List[dict]], List[dict]],
         stat_file_path: Optional[DPPath] = None,
     ):
         """
-        Compute or load the statistics parameters of the model,
-        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
-        When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
-        and saved in the `stat_file_path`(s).
-        When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
-        and load the calculated statistics parameters.
+        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
 
         Parameters
         ----------
-        sampled_func
-            The sampled data frames from different data systems.
-        stat_file_path
-            The path to the statistics files.
+        merged : Union[Callable[[], List[dict]], List[dict]]
+            - List[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        stat_file_path : Optional[DPPath]
+            The path to the stat file.
         """
         raise NotImplementedError
 
     def change_out_bias(
         self,
-        merged,
-        origin_type_map,
-        full_type_map,
+        sample_merged,
         bias_adjust_mode="change-by-statistic",
     ) -> None:
         """Change the output bias according to the input data and the pretrained model.
 
         Parameters
         ----------
-        merged : Union[Callable[[], List[dict]], List[dict]]
+        sample_merged : Union[Callable[[], List[dict]], List[dict]]
             - List[dict]: A list of data samples from various data systems.
                 Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
                 originating from the `i`-th data system.
            - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                 only when needed. Since the sampling process can be slow and memory-intensive,
                 the lazy function helps by only sampling once.
-        origin_type_map : List[str]
-            The original type_map in dataset, they are targets to change the output bias.
-        full_type_map : List[str]
-            The full type_map in pre-trained model
         bias_adjust_mode : str
             The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
             'change-by-statistic' : perform predictions on labels of target dataset,
             and do least square on the errors to obtain the target shift as bias.
            'set-by-statistic' : directly use the statistic output bias in the target dataset.
         """
-        sorter = np.argsort(full_type_map)
-        missing_types = [t for t in origin_type_map if t not in full_type_map]
-        assert (
-            not missing_types
-        ), f"Some types are not in the pre-trained model: {list(missing_types)} !"
-        idx_type_map = sorter[
-            np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
-        ]
-        original_bias = self.get_out_bias()
         if bias_adjust_mode == "change-by-statistic":
             delta_bias = compute_output_stats(
-                merged,
+                sample_merged,
                 self.get_ntypes(),
-                keys=["energy"],
-                model_forward=self.get_forward_wrapper_func(),
+                keys=self.get_output_keys(),
+                model_forward=self._get_forward_wrapper_func(),
             )["energy"]
             self.set_out_bias(delta_bias, add=True)
         elif bias_adjust_mode == "set-by-statistic":
             bias_atom = compute_output_stats(
-                merged,
+                sample_merged,
                 self.get_ntypes(),
-                keys=["energy"],
+                keys=self.get_output_keys(),
             )["energy"]
             self.set_out_bias(bias_atom)
         else:
             raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
-        bias_atom = self.get_out_bias()
-        log.info(
-            f"Change output bias of {origin_type_map!s} "
-            f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
-            f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}."
-        )
+
+    def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
+        """Get a forward wrapper of the atomic model for output bias calculation."""
+
+        def model_forward(coord, atype, box, fparam=None, aparam=None):
+            with torch.no_grad():  # it's essential for pure torch forward function to use auto_batchsize
+                (
+                    extended_coord,
+                    extended_atype,
+                    mapping,
+                    nlist,
+                ) = extend_input_and_build_neighbor_list(
+                    coord,
+                    atype,
+                    self.get_rcut(),
+                    self.get_sel(),
+                    mixed_types=self.mixed_types(),
+                    box=box,
+                )
+                atomic_ret = self.forward_common_atomic(
+                    extended_coord,
+                    extended_atype,
+                    nlist,
+                    mapping=mapping,
+                    fparam=fparam,
+                    aparam=aparam,
+                )
+                return {kk: vv.detach() for kk, vv in atomic_ret.items()}
+
+        return model_forward
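`compute_output_stats` itself is not part of this diff, but the docstring above pins down what 'change-by-statistic' computes: run the pre-trained model on the target data and least-square-fit the prediction errors to obtain a per-type bias shift. A minimal NumPy sketch of that fit (all names hypothetical, not the actual implementation):

    import numpy as np

    def delta_bias_by_statistic(type_count, e_label, e_pred):
        """Least-squares per-type bias shift from prediction errors.

        type_count : (nframes, ntypes) atoms of each type per frame
        e_label    : (nframes,) labeled total energies of the target data
        e_pred     : (nframes,) total energies predicted by the model
        """
        # Solve type_count @ delta ~= e_label - e_pred in the least-squares
        # sense; the shift is then applied via set_out_bias(delta, add=True).
        delta, *_ = np.linalg.lstsq(type_count, e_label - e_pred, rcond=None)
        return delta  # shape (ntypes,)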
11 changes: 3 additions & 8 deletions deepmd/pt/model/model/make_model.py
@@ -172,11 +172,12 @@ def forward_common(
         model_predict = self.output_type_cast(model_predict, input_prec)
         return model_predict
 
+    def get_out_bias(self) -> torch.Tensor:
+        return self.atomic_model.get_out_bias()
+
     def change_out_bias(
         self,
         merged,
-        origin_type_map,
-        full_type_map,
         bias_adjust_mode="change-by-statistic",
     ) -> None:
         """Change the output bias of atomic model according to the input data and the pretrained model.
@@ -190,10 +191,6 @@ def change_out_bias(
             - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
                 only when needed. Since the sampling process can be slow and memory-intensive,
                 the lazy function helps by only sampling once.
-        origin_type_map : List[str]
-            The original type_map in dataset, they are targets to change the output bias.
-        full_type_map : List[str]
-            The full type_map in pre-trained model
         bias_adjust_mode : str
             The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
             'change-by-statistic' : perform predictions on labels of target dataset,
@@ -202,8 +199,6 @@ def change_out_bias(
             'set-by-statistic' : directly use the statistic output bias in the target dataset.
         """
         self.atomic_model.change_out_bias(
             merged,
-            origin_type_map,
-            full_type_map,
             bias_adjust_mode=bias_adjust_mode,
         )
 
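The model-level `get_out_bias` added here is a plain pass-through to the atomic model; it lets the trainer read the bias before and after adjustment without reaching into `model.atomic_model`. The intended call pattern (mirroring `_model_change_out_bias` in training.py below):

    old_bias = model.get_out_bias()  # delegates to atomic_model.get_out_bias()
    model.change_out_bias(sample_func, bias_adjust_mode="change-by-statistic")
    new_bias = model.get_out_bias()  # logged as "from old_bias to new_bias"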
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
@@ -167,11 +167,11 @@ def compute_output_stats(
         bias_atom_e = compute_output_stats(
             merged,
             self.ntypes,
-            keys=["energy"],
+            keys=[self.var_name],
             stat_file_path=stat_file_path,
             rcond=self.rcond,
             atom_ener=self.atom_ener,
-        )["energy"]
+        )[self.var_name]
         self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))
 
     def output_def(self) -> FittingOutputDef:
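Keying the statistics on `self.var_name` removes a hard-coded `"energy"` from this path. For the energy fitting net, `var_name` should be `"energy"`, so behavior is unchanged; a hypothetical invariant fitting with another variable name would now compute its bias under its own key:

    # Sketch: the same code path now serves any scalar fitting variable.
    # E.g. a hypothetical InvarFitting(var_name="dos") would yield:
    stats = compute_output_stats(merged, self.ntypes, keys=[self.var_name])
    bias_atom_e = stats[self.var_name]  # stored under "dos", not "energy"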
42 changes: 33 additions & 9 deletions deepmd/pt/train/training.py
@@ -570,21 +570,17 @@ def single_model_finetune(
                 _model_params["new_type_map"],
             )
             if isinstance(_model, EnergyModel):
-                _model.change_out_bias(
-                    _sample_func,
-                    bias_adjust_mode=_model_params.get(
-                        "bias_adjust_mode", "change-by-statistic"
-                    ),
-                    origin_type_map=new_type_map,
-                    full_type_map=old_type_map,
+                _model = _model_change_out_bias(
+                    _model, new_type_map, _sample_func, _model_params
                 )
             else:
                 # need to updated
                 pass
             return _model
 
         # finetune
         if not self.multi_task:
-            single_model_finetune(
+            self.model = single_model_finetune(
                 self.model, model_params, self.get_sample_func
             )
         else:
@@ -593,7 +589,7 @@ def single_model_finetune(
             log.info(
                 f"Model branch {model_key} will be fine-tuned. This may take a long time..."
             )
-            single_model_finetune(
+            self.model[model_key] = single_model_finetune(
                 self.model[model_key],
                 model_params["model_dict"][model_key],
                 self.get_sample_func[model_key],
@@ -1148,3 +1144,31 @@ def print_on_training(self, fout, step_id, cur_lr, train_results, valid_results)
         print_str += " %8.1e\n" % cur_lr
         fout.write(print_str)
         fout.flush()
+
+
+def _model_change_out_bias(
+    _model,
+    new_type_map,
+    _sample_func,
+    _model_params,
+):
+    old_bias = _model.get_out_bias()
+    _model.change_out_bias(
+        _sample_func,
+        bias_adjust_mode=_model_params.get("bias_adjust_mode", "change-by-statistic"),
+    )
+    new_bias = _model.get_out_bias()
+
+    model_type_map = _model.get_type_map()
+    sorter = np.argsort(model_type_map)
+    missing_types = [t for t in new_type_map if t not in model_type_map]
+    assert (
+        not missing_types
+    ), f"Some types are not in the pre-trained model: {list(missing_types)} !"
+    idx_type_map = sorter[np.searchsorted(model_type_map, new_type_map, sorter=sorter)]
+    log.info(
+        f"Change output bias of {new_type_map!s} "
+        f"from {to_numpy_array(old_bias[idx_type_map]).reshape(-1)!s} "
+        f"to {to_numpy_array(new_bias[idx_type_map]).reshape(-1)!s}."
+    )
+    return _model
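The `argsort`/`searchsorted` pair in `_model_change_out_bias` maps the dataset's type names to row indices of the pre-trained model's bias tensor, so the log reports the correct entries even when the new type map is a reordered subset. A worked example with hypothetical type maps:

    import numpy as np

    model_type_map = ["H", "C", "N", "O"]  # hypothetical pre-trained types
    new_type_map = ["O", "H"]              # types seen in the new dataset

    sorter = np.argsort(model_type_map)    # [1, 0, 2, 3]: sorts the names
    idx_type_map = sorter[
        np.searchsorted(model_type_map, new_type_map, sorter=sorter)
    ]
    print(idx_type_map)  # [3 0] -> bias rows for "O" and "H"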
(The remaining 3 changed files are not shown here.)