Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compute output stat for a dict of labels. #3628

Merged
merged 5 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""
model_output_type = list(self.atomic_output_def().keys())
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
out_name = model_output_type[0]

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
Expand All @@ -220,7 +216,7 @@
fparam=fparam,
aparam=aparam,
)
return atomic_ret[out_name].detach()
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

Check warning on line 219 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L219

Added line #L219 was not covered by tests

return model_forward

Expand Down Expand Up @@ -287,14 +283,16 @@
delta_bias = compute_output_stats(
merged,
self.get_ntypes(),
keys=["energy"],
model_forward=self.get_forward_wrapper_func(),
)
)["energy"]
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
self.get_ntypes(),
)
keys=["energy"],
)["energy"]
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,13 @@ def compute_or_load_stat(

"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,13 @@ def compute_output_stats(

"""
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
merged,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
Expand Down
132 changes: 91 additions & 41 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from pathlib import (

Check warning on line 3 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L3

Added line #L3 was not covered by tests
Path,
)
from typing import (
Callable,
List,
Expand Down Expand Up @@ -78,9 +81,39 @@
return lst


def restore_from_file(
    stat_file_path: Path,
    keys: Optional[List[str]] = None,
) -> Optional[dict]:
    """Load previously saved per-key bias arrays from a directory.

    Parameters
    ----------
    stat_file_path : Path
        Directory expected to hold one ``bias_atom_{key}.npy`` file per
        requested key. ``None`` means no directory was given.
    keys : List[str], optional
        Names of the outputs to restore. Defaults to ``["energy"]``.

    Returns
    -------
    Optional[dict]
        A dict mapping each key to the array read with ``np.load``, or
        ``None`` when ``stat_file_path`` is ``None`` or at least one of the
        expected files is missing (nothing is loaded partially).
    """
    if stat_file_path is None:
        return None
    if keys is None:
        # build the default per call instead of sharing a mutable default
        keys = ["energy"]
    stat_files = {kk: stat_file_path / f"bias_atom_{kk}.npy" for kk in keys}
    # all-or-nothing: a single missing file invalidates the whole restore
    if not all(fp.is_file() for fp in stat_files.values()):
        return None
    return {kk: np.load(fp) for kk, fp in stat_files.items()}

Check warning on line 99 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L95-L99

Added lines #L95 - L99 were not covered by tests


def save_to_file(
    stat_file_path: Path,
    results: dict,
):
    """Write every entry of ``results`` to its own ``.npy`` file.

    Parameters
    ----------
    stat_file_path : Path
        Target directory. It is created (including parents) if it does not
        exist yet.
    results : dict
        Mapping from key to the value to store; each value is written with
        ``np.save`` to ``bias_atom_{key}.npy`` inside ``stat_file_path``.

    Raises
    ------
    ValueError
        If ``stat_file_path`` is ``None``.
    """
    # explicit check instead of ``assert`` so it still fires under ``python -O``
    if stat_file_path is None:
        raise ValueError("stat_file_path must not be None when saving statistics")
    stat_file_path.mkdir(exist_ok=True, parents=True)
    for kk, vv in results.items():
        np.save(stat_file_path / f"bias_atom_{kk}.npy", vv)

Check warning on line 110 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L106-L110

Added lines #L106 - L110 were not covered by tests


def compute_output_stats(
merged: Union[Callable[[], List[dict]], List[dict]],
ntypes: int,
keys: List[str] = ["energy"],
stat_file_path: Optional[DPPath] = None,
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
Expand Down Expand Up @@ -112,17 +145,15 @@
which will be subtracted from the energy label of the data.
The difference will then be used to calculate the delta complement energy bias for each type.
"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
bias_atom_e = restore_from_file(stat_file_path, keys)

Check warning on line 148 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L148

Added line #L148 was not covered by tests

if bias_atom_e is None:

Check warning on line 150 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L150

Added line #L150 was not covered by tests
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
energy = [item["energy"] for item in sampled]
outputs = {kk: [item[kk] for item in sampled] for kk in keys}

Check warning on line 156 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L156

Added line #L156 was not covered by tests
data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
for system in sampled:
Expand All @@ -133,7 +164,7 @@
system[natoms_key][:, 2:] *= type_mask.unsqueeze(0)
input_natoms = [item[natoms_key] for item in sampled]
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}

Check warning on line 167 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L167

Added line #L167 was not covered by tests
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
if atom_ener is not None and len(atom_ener) > 0:
Expand All @@ -144,16 +175,20 @@
assigned_atom_ener = None
if model_forward is None:
# only use statistics result
bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
# [0]: take the first output (mean) of compute_stats_from_redu
bias_atom_e = {

Check warning on line 179 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L179

Added line #L179 was not covered by tests
kk: compute_stats_from_redu(
merged_output[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
else:
# subtract the model bias and output the delta bias
auto_batch_size = AutoBatchSize()
energy_predict = []
model_predict = {kk: [] for kk in keys}

Check warning on line 191 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L191

Added line #L191 was not covered by tests
for system in sampled:
nframes = system["coord"].shape[0]
coord, atype, box, natoms = (
Expand All @@ -174,34 +209,49 @@
**kwargs,
)

energy = (
model_forward_auto_batch_size(
coord, atype, box, fparam=fparam, aparam=aparam
)
.reshape(nframes, -1)
.sum(-1)
sample_predict = model_forward_auto_batch_size(

Check warning on line 212 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L212

Added line #L212 was not covered by tests
coord, atype, box, fparam=fparam, aparam=aparam
)
energy_predict.append(to_numpy_array(energy).reshape([nframes, 1]))

energy_predict = np.concatenate(energy_predict)
bias_diff = merged_energy - energy_predict
bias_atom_e, _ = compute_stats_from_redu(
bias_diff,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
unbias_e = energy_predict + merged_natoms @ bias_atom_e

for kk in keys:
model_predict[kk].append(

Check warning on line 217 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L216-L217

Added lines #L216 - L217 were not covered by tests
to_numpy_array(
torch.sum(sample_predict[kk], dim=1) # nf x nloc x odims
)
)

model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys}

Check warning on line 223 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L223

Added line #L223 was not covered by tests

bias_diff = {kk: merged_output[kk] - model_predict[kk] for kk in keys}
bias_atom_e = {

Check warning on line 226 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L225-L226

Added lines #L225 - L226 were not covered by tests
kk: compute_stats_from_redu(
bias_diff[kk],
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)[0]
for kk in keys
}
unbias_e = {

Check warning on line 235 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L235

Added line #L235 was not covered by tests
kk: model_predict[kk] + merged_natoms @ bias_atom_e[kk] for kk in keys
}
atom_numbs = merged_natoms.sum(-1)
rmse_ae = np.sqrt(
np.mean(
np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs)
for kk in keys:
rmse_ae = np.sqrt(

Check warning on line 240 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L239-L240

Added lines #L239 - L240 were not covered by tests
np.mean(
np.square(
(unbias_e[kk].ravel() - merged_output[kk].ravel())
/ atom_numbs
)
)
)
)
log.info(
f"RMSE of energy per atom after linear regression is: {rmse_ae} eV/atom."
)
log.info(

Check warning on line 248 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L248

Added line #L248 was not covered by tests
f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}."
)

if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
return to_torch_tensor(bias_atom_e)
save_to_file(stat_file_path, bias_atom_e)

Check warning on line 253 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L253

Added line #L253 was not covered by tests

ret = {kk: to_torch_tensor(bias_atom_e[kk]) for kk in keys}

Check warning on line 255 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L255

Added line #L255 was not covered by tests

return ret

Check warning on line 257 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L257

Added line #L257 was not covered by tests
86 changes: 86 additions & 0 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import unittest
from abc import (
ABC,
Expand Down Expand Up @@ -29,7 +30,14 @@
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.stat import make_stat_input
from deepmd.pt.utils.stat import make_stat_input as my_make
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.tf.common import (
expand_sys_str,
)
Expand Down Expand Up @@ -325,5 +333,83 @@ def tf_compute_input_stats(self):
)


class TestOutputStat(unittest.TestCase):
    """Exercise ``compute_output_stats`` for the ``"energy"`` key.

    Covers three paths: computing the bias from sampled data, reloading the
    bias from the files written by the first call, and honouring an
    explicitly assigned ``atom_ener``.
    """

    def setUp(self):
        self.data_file = [str(Path(__file__).parent / "water/data/data_0")]
        self.type_map = ["O", "H"]  # by dataset
        self.data = DpLoaderSet(
            self.data_file,
            batch_size=1,
            type_map=self.type_map,
        )
        self.data.add_data_requirement(energy_data_requirement)
        self.sampled = make_stat_input(
            self.data.systems,
            self.data.dataloaders,
            nbatches=1,
        )
        self.stat_file_path = Path("my_output_stat")
        # start from a clean state so the first call cannot hit stale files
        if self.stat_file_path.is_dir():
            shutil.rmtree(self.stat_file_path)

    def tearDown(self):
        # always remove the scratch directory, even when an assertion failed
        shutil.rmtree(self.stat_file_path, ignore_errors=True)

    def test(self):
        type_map = self.type_map
        stat_file_path = self.stat_file_path
        atom_ener = np.array([3.0, 5.0]).reshape(2, 1)

        # compute from sample
        ret0 = compute_output_stats(
            self.sampled,
            len(type_map),
            keys=["energy"],
            stat_file_path=stat_file_path,
            atom_ener=None,
            model_forward=None,
        )
        # ground truth: least-squares fit of the energy on per-type atom counts
        ntest = 1
        atom_nums = np.tile(
            np.bincount(to_numpy_array(self.sampled[0]["atype"][0])),
            (ntest, 1),
        )
        energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest])
        ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[0]

        # check values
        np.testing.assert_almost_equal(
            to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10
        )
        self.assertTrue(stat_file_path.is_dir())

        def raise_error():
            raise RuntimeError

        # The sampler below raises when called, so this call can only succeed
        # if the statistics are loaded from the files written by the call above.
        ret1 = compute_output_stats(
            raise_error,
            len(type_map),
            keys=["energy"],
            stat_file_path=stat_file_path,
            atom_ener=None,
            model_forward=None,
        )
        np.testing.assert_almost_equal(
            to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10
        )
        # drop the saved files so the next call recomputes instead of reloading
        shutil.rmtree(stat_file_path)

        # from assigned atom_ener
        ret2 = compute_output_stats(
            self.sampled,
            len(type_map),
            keys=["energy"],
            stat_file_path=stat_file_path,
            atom_ener=atom_ener,
            model_forward=None,
        )
        np.testing.assert_almost_equal(
            to_numpy_array(ret2["energy"]), atom_ener, decimal=10
        )


if __name__ == "__main__":
unittest.main()
Loading