Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compute output stat for a dict of labels. #3628

Merged
merged 5 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""
model_output_type = list(self.atomic_output_def().keys())
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
out_name = model_output_type[0]

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
Expand All @@ -220,7 +216,7 @@
fparam=fparam,
aparam=aparam,
)
return atomic_ret[out_name].detach()
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

Check warning on line 219 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L219

Added line #L219 was not covered by tests

return model_forward

Expand Down Expand Up @@ -287,14 +283,16 @@
delta_bias = compute_output_stats(
merged,
self.get_ntypes(),
keys=["energy"],
model_forward=self.get_forward_wrapper_func(),
)
)["energy"]
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
self.get_ntypes(),
)
keys=["energy"],
)["energy"]
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,13 @@ def compute_or_load_stat(

"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,13 @@ def compute_output_stats(

"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
Expand Down
129 changes: 88 additions & 41 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,39 @@
return lst


def restore_from_file(

Check warning on line 81 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L81

Added line #L81 was not covered by tests
stat_file_path: DPPath,
keys: List[str] = ["energy"],
) -> Optional[dict]:
if stat_file_path is None:
return None
stat_files = [stat_file_path / f"bias_atom_{kk}" for kk in keys]
if any(not (ii.is_file()) for ii in stat_files):
return None
anyangml marked this conversation as resolved.
Show resolved Hide resolved
ret = {}

Check warning on line 90 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L85-L90

Added lines #L85 - L90 were not covered by tests

for kk in keys:
fp = stat_file_path / f"bias_atom_{kk}"
assert fp.is_file()
ret[kk] = fp.load_numpy()
return ret

Check warning on line 96 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L92-L96

Added lines #L92 - L96 were not covered by tests


def save_to_file(

Check warning on line 99 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L99

Added line #L99 was not covered by tests
stat_file_path: DPPath,
results: dict,
):
assert stat_file_path is not None
stat_file_path.mkdir(exist_ok=True, parents=True)
for kk, vv in results.items():
fp = stat_file_path / f"bias_atom_{kk}"
fp.save_numpy(vv)

Check warning on line 107 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L103-L107

Added lines #L103 - L107 were not covered by tests


def compute_output_stats(
merged: Union[Callable[[], List[dict]], List[dict]],
ntypes: int,
keys: List[str] = ["energy"],
stat_file_path: Optional[DPPath] = None,
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
Expand Down Expand Up @@ -112,17 +142,15 @@
which will be subtracted from the energy label of the data.
The difference will then be used to calculate the delta complement energy bias for each type.
"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
bias_atom_e = restore_from_file(stat_file_path, keys)

Check warning on line 145 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L145

Added line #L145 was not covered by tests

if bias_atom_e is None:

Check warning on line 147 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L147

Added line #L147 was not covered by tests
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
energy = [item["energy"] for item in sampled]
outputs = {kk: [item[kk] for item in sampled] for kk in keys}

Check warning on line 153 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L153

Added line #L153 was not covered by tests
data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
Expand All @@ -133,7 +161,7 @@
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}

Check warning on line 164 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L164

Added line #L164 was not covered by tests
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
if atom_ener is not None and len(atom_ener) > 0:
Expand All @@ -144,16 +172,20 @@
assigned_atom_ener = None
if model_forward is None:
# only use statistics result
bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
# [0]: take the first otuput (mean) of compute_stats_from_redu
bias_atom_e = {

Check warning on line 176 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L176

Added line #L176 was not covered by tests
kk: compute_stats_from_redu(
merged_output[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
else:
# subtract the model bias and output the delta bias
auto_batch_size = AutoBatchSize()
energy_predict = []
model_predict = {kk: [] for kk in keys}

Check warning on line 188 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L188

Added line #L188 was not covered by tests
for system in sampled:
nframes = system["coord"].shape[0]
coord, atype, box, natoms = (
Expand All @@ -174,34 +206,49 @@
**kwargs,
)

energy = (
model_forward_auto_batch_size(
coord, atype, box, fparam=fparam, aparam=aparam
)
.reshape(nframes, -1)
.sum(-1)
sample_predict = model_forward_auto_batch_size(

Check warning on line 209 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L209

Added line #L209 was not covered by tests
coord, atype, box, fparam=fparam, aparam=aparam
)
energy_predict.append(to_numpy_array(energy).reshape([nframes, 1]))

energy_predict = np.concatenate(energy_predict)
bias_diff = merged_energy - energy_predict
bias_atom_e, _ = compute_stats_from_redu(
bias_diff,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
unbias_e = energy_predict + merged_natoms @ bias_atom_e

for kk in keys:
model_predict[kk].append(

Check warning on line 214 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L213-L214

Added lines #L213 - L214 were not covered by tests
to_numpy_array(
torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims
)
)

model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys}

Check warning on line 220 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L220

Added line #L220 was not covered by tests

bias_diff = {kk: merged_output[kk] - model_predict[kk] for kk in keys}
bias_atom_e = {

Check warning on line 223 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L222-L223

Added lines #L222 - L223 were not covered by tests
kk: compute_stats_from_redu(
bias_diff[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
unbias_e = {

Check warning on line 232 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L232

Added line #L232 was not covered by tests
kk: model_predict[kk] + merged_natoms @ bias_atom_e[kk] for kk in keys
}
atom_numbs = merged_natoms.sum(-1)
rmse_ae = np.sqrt(
np.mean(
np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs)
for kk in keys:
rmse_ae = np.sqrt(

Check warning on line 237 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L236-L237

Added lines #L236 - L237 were not covered by tests
np.mean(
np.square(
(unbias_e[kk].ravel() - merged_output[kk].ravel())
/ atom_numbs
)
)
)
)
log.info(
f"RMSE of energy per atom after linear regression is: {rmse_ae} eV/atom."
)
log.info(

Check warning on line 245 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L245

Added line #L245 was not covered by tests
f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}."
)

if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
return to_torch_tensor(bias_atom_e)
save_to_file(stat_file_path, bias_atom_e)

Check warning on line 250 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L250

Added line #L250 was not covered by tests

ret = {kk: to_torch_tensor(bias_atom_e[kk]) for kk in keys}

Check warning on line 252 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L252

Added line #L252 was not covered by tests

return ret

Check warning on line 254 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L254

Added line #L254 was not covered by tests
100 changes: 100 additions & 0 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile
import unittest
from abc import (
ABC,
Expand All @@ -11,6 +12,7 @@
)

import dpdata
import h5py
import numpy as np
import torch

Expand All @@ -29,7 +31,14 @@
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.stat import make_stat_input
from deepmd.pt.utils.stat import make_stat_input as my_make
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.tf.common import (
expand_sys_str,
)
Expand All @@ -47,6 +56,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.path import (
DPPath,
)

CUR_DIR = os.path.dirname(__file__)

Expand Down Expand Up @@ -325,5 +337,93 @@ def tf_compute_input_stats(self):
)


class TestOutputStat(unittest.TestCase):
def setUp(self):
self.data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.type_map = ["O", "H"] # by dataset
self.data = DpLoaderSet(
self.data_file,
batch_size=1,
type_map=self.type_map,
)
self.data.add_data_requirement(energy_data_requirement)
self.sampled = make_stat_input(
self.data.systems,
self.data.dataloaders,
nbatches=1,
)
self.tempdir = tempfile.TemporaryDirectory()
h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve())
with h5py.File(h5file, "w") as f:
pass
self.stat_file_path = DPPath(h5file, "a")

def tearDown(self):
self.tempdir.cleanup()

def test_calc_and_load(self):
stat_file_path = self.stat_file_path
type_map = self.type_map

# compute from sample
ret0 = compute_output_stats(
self.sampled,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
model_forward=None,
)
# ground truth
ntest = 1
atom_nums = np.tile(
np.bincount(to_numpy_array(self.sampled[0]["atype"][0])),
(ntest, 1),
)
energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest])
ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[0]

# check values
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10
)
# self.assertTrue(stat_file_path.is_dir())

def raise_error():
raise RuntimeError

# hack!!!
# suppose to load stat from file, if from sample, an error will raise.
ret1 = compute_output_stats(
raise_error,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=None,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10
)

def test_assigned(self):
atom_ener = np.array([3.0, 5.0]).reshape(2, 1)
stat_file_path = self.stat_file_path
type_map = self.type_map

# from assigned atom_ener
ret2 = compute_output_stats(
self.sampled,
len(type_map),
keys=["energy"],
stat_file_path=stat_file_path,
atom_ener=atom_ener,
model_forward=None,
)
np.testing.assert_almost_equal(
to_numpy_array(ret2["energy"]), atom_ener, decimal=10
)


if __name__ == "__main__":
unittest.main()