diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index fa30655f8a..129b8dc11d 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -192,10 +192,6 @@ def serialize(self) -> dict: def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]: """Get a forward wrapper of the atomic model for output bias calculation.""" - model_output_type = list(self.atomic_output_def().keys()) - if "mask" in model_output_type: - model_output_type.pop(model_output_type.index("mask")) - out_name = model_output_type[0] def model_forward(coord, atype, box, fparam=None, aparam=None): with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize @@ -220,7 +216,7 @@ def model_forward(coord, atype, box, fparam=None, aparam=None): fparam=fparam, aparam=aparam, ) - return atomic_ret[out_name].detach() + return {kk: vv.detach() for kk, vv in atomic_ret.items()} return model_forward @@ -287,14 +283,16 @@ def change_out_bias( delta_bias = compute_output_stats( merged, self.get_ntypes(), + keys=["energy"], model_forward=self.get_forward_wrapper_func(), - ) + )["energy"] self.set_out_bias(delta_bias, add=True) elif bias_adjust_mode == "set-by-statistic": bias_atom = compute_output_stats( merged, self.get_ntypes(), - ) + keys=["energy"], + )["energy"] self.set_out_bias(bias_atom) else: raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index c20abf6a12..4db77790e9 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -228,8 +228,13 @@ def compute_or_load_stat( """ bias_atom_e = compute_output_stats( - merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener - ) + merged, + self.ntypes, + keys=["energy"], + stat_file_path=stat_file_path, + rcond=self.rcond, + atom_ener=self.atom_ener, + )["energy"] self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1]) ) diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index afb1d73658..585f697193 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -165,8 +165,13 @@ def compute_output_stats( """ bias_atom_e = compute_output_stats( - merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener - ) + merged, + self.ntypes, + keys=["energy"], + stat_file_path=stat_file_path, + rcond=self.rcond, + atom_ener=self.atom_ener, + )["energy"] self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out])) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 328ec30908..bf5645f02f 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from pathlib import ( + Path, +) from typing import ( Callable, List, @@ -78,9 +81,39 @@ def make_stat_input(datasets, dataloaders, nbatches): return lst +def restore_from_file( + stat_file_path: Path, + keys: List[str] = ["energy"], +) -> Optional[dict]: + if stat_file_path is None: + return None + stat_files = [stat_file_path / f"bias_atom_{kk}.npy" for kk in keys] + if any(not (ii.is_file()) for ii in stat_files): + return None + ret = {} + + for kk in keys: + fp = stat_file_path / f"bias_atom_{kk}.npy" + assert fp.is_file() + ret[kk] = np.load(fp) + return ret + + +def save_to_file( + stat_file_path: Path, + results: dict, +): + assert stat_file_path is not None + stat_file_path.mkdir(exist_ok=True, parents=True) + for kk, vv in results.items(): + fp = stat_file_path / f"bias_atom_{kk}.npy" + np.save(fp, vv) + + def compute_output_stats( merged: Union[Callable[[], List[dict]], List[dict]], ntypes: int, + keys: List[str] = ["energy"], stat_file_path: Optional[DPPath] = None, rcond: Optional[float] = None, atom_ener: Optional[List[float]] = None, @@ -112,17 +145,15 @@ def compute_output_stats( which will be subtracted from the energy label of the data. The difference will then be used to calculate the delta complement energy bias for each type. """ - if stat_file_path is not None: - stat_file_path = stat_file_path / "bias_atom_e" - if stat_file_path is not None and stat_file_path.is_file(): - bias_atom_e = stat_file_path.load_numpy() - else: + bias_atom_e = restore_from_file(stat_file_path, keys) + + if bias_atom_e is None: if callable(merged): # only get data for once sampled = merged() else: sampled = merged - energy = [item["energy"] for item in sampled] + outputs = {kk: [item[kk] for item in sampled] for kk in keys} data_mixed_type = "real_natoms_vec" in sampled[0] natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec" for system in sampled: @@ -133,7 +164,7 @@ def compute_output_stats( system[natoms_key][:, 2:] *= type_mask.unsqueeze(0) input_natoms = [item[natoms_key] for item in sampled] # shape: (nframes, ndim) - merged_energy = to_numpy_array(torch.cat(energy)) + merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys} # shape: (nframes, ntypes) merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:]) if atom_ener is not None and len(atom_ener) > 0: @@ -144,16 +175,20 @@ def compute_output_stats( assigned_atom_ener = None if model_forward is None: # only use statistics result - bias_atom_e, _ = compute_stats_from_redu( - merged_energy, - merged_natoms, - assigned_bias=assigned_atom_ener, - rcond=rcond, - ) + # [0]: take the first otuput (mean) of compute_stats_from_redu + bias_atom_e = { + kk: compute_stats_from_redu( + merged_output[kk], + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + )[0] + for kk in keys + } else: # subtract the model bias and output the delta bias auto_batch_size = AutoBatchSize() - energy_predict = [] + model_predict = {kk: [] for kk in keys} for system in sampled: nframes = system["coord"].shape[0] coord, atype, box, natoms = ( @@ -174,34 +209,49 @@ def model_forward_auto_batch_size(*args, **kwargs): **kwargs, ) - energy = ( - model_forward_auto_batch_size( - coord, atype, box, fparam=fparam, aparam=aparam - ) - .reshape(nframes, -1) - .sum(-1) + sample_predict = model_forward_auto_batch_size( + coord, atype, box, fparam=fparam, aparam=aparam ) - energy_predict.append(to_numpy_array(energy).reshape([nframes, 1])) - - energy_predict = np.concatenate(energy_predict) - bias_diff = merged_energy - energy_predict - bias_atom_e, _ = compute_stats_from_redu( - bias_diff, - merged_natoms, - assigned_bias=assigned_atom_ener, - rcond=rcond, - ) - unbias_e = energy_predict + merged_natoms @ bias_atom_e + + for kk in keys: + model_predict[kk].append( + to_numpy_array( + torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims + ) + ) + + model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys} + + bias_diff = {kk: merged_output[kk] - model_predict[kk] for kk in keys} + bias_atom_e = { + kk: compute_stats_from_redu( + bias_diff[kk], + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + )[0] + for kk in keys + } + unbias_e = { + kk: model_predict[kk] + merged_natoms @ bias_atom_e[kk] for kk in keys + } atom_numbs = merged_natoms.sum(-1) - rmse_ae = np.sqrt( - np.mean( - np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs) + for kk in keys: + rmse_ae = np.sqrt( + np.mean( + np.square( + (unbias_e[kk].ravel() - merged_output[kk].ravel()) + / atom_numbs + ) + ) ) - ) - log.info( - f"RMSE of energy per atom after linear regression is: {rmse_ae} eV/atom." - ) + log.info( + f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." + ) + if stat_file_path is not None: - stat_file_path.save_numpy(bias_atom_e) - assert all(x is not None for x in [bias_atom_e]) - return to_torch_tensor(bias_atom_e) + save_to_file(stat_file_path, bias_atom_e) + + ret = {kk: to_torch_tensor(bias_atom_e[kk]) for kk in keys} + + return ret diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 51ca903bc2..7d23ca5920 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json import os +import shutil import unittest from abc import ( ABC, @@ -29,7 +30,14 @@ from deepmd.pt.utils.dataloader import ( DpLoaderSet, ) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) +from deepmd.pt.utils.stat import make_stat_input from deepmd.pt.utils.stat import make_stat_input as my_make +from deepmd.pt.utils.utils import ( + to_numpy_array, +) from deepmd.tf.common import ( expand_sys_str, ) @@ -325,5 +333,83 @@ def tf_compute_input_stats(self): ) +class TestOutputStat(unittest.TestCase): + def test(self): + self.data_file = [str(Path(__file__).parent / "water/data/data_0")] + type_map = ["O", "H"] # by dataset + self.data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=type_map, + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, + ) + stat_file_path = Path("my_output_stat") + stat_file_path.mkdir(exist_ok=True) + atom_ener = np.array([3.0, 5.0]).reshape(2, 1) + + if stat_file_path.is_dir(): + shutil.rmtree(stat_file_path) + # compute from sample + ret0 = compute_output_stats( + self.sampled, + len(type_map), + keys=["energy"], + stat_file_path=stat_file_path, + atom_ener=None, + model_forward=None, + ) + # ground truth + ntest = 1 + atom_nums = np.tile( + np.bincount(to_numpy_array(self.sampled[0]["atype"][0])), + (ntest, 1), + ) + energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest]) + ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[0] + + # check values + np.testing.assert_almost_equal( + to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10 + ) + self.assertTrue(stat_file_path.is_dir()) + + def raise_error(): + raise RuntimeError + + # hack!!! + # suppose to load stat from file, if from sample, an error will raise. + ret1 = compute_output_stats( + raise_error, + len(type_map), + keys=["energy"], + stat_file_path=stat_file_path, + atom_ener=None, + model_forward=None, + ) + np.testing.assert_almost_equal( + to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10 + ) + shutil.rmtree(stat_file_path) + + # from assigned atom_ener + ret2 = compute_output_stats( + self.sampled, + len(type_map), + keys=["energy"], + stat_file_path=stat_file_path, + atom_ener=atom_ener, + model_forward=None, + ) + np.testing.assert_almost_equal( + to_numpy_array(ret2["energy"]), atom_ener, decimal=10 + ) + shutil.rmtree(stat_file_path) + + if __name__ == "__main__": unittest.main()