-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(pt): add universal test for loss (#4354)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new `LossTest` class for enhanced testing of loss functions. - Added multiple parameterized test functions for various loss functions in the new `test_loss.py` file. - **Bug Fixes** - Corrected tensor operations in the `DOSLoss` class to ensure accurate cumulative sum calculations. - **Documentation** - Added SPDX license identifiers to multiple files for clarity on licensing terms. - **Chores** - Refactored data conversion methods in the `PTTestCase` class for improved handling of tensors and arrays. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Duo <[email protected]> Co-authored-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
10 changed files
with
388 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
|
||
from .utils import ( | ||
LossTestCase, | ||
) | ||
|
||
|
||
class LossTest(LossTestCase): | ||
def setUp(self) -> None: | ||
LossTestCase.setUp(self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import numpy as np | ||
|
||
from deepmd.utils.data import ( | ||
DataRequirementItem, | ||
) | ||
|
||
from .....seed import ( | ||
GLOBAL_SEED, | ||
) | ||
|
||
|
||
class LossTestCase: | ||
"""Common test case for loss function.""" | ||
|
||
def setUp(self): | ||
pass | ||
|
||
def test_label_keys(self): | ||
module = self.forward_wrapper(self.module) | ||
label_requirement = self.module.label_requirement | ||
label_dict = {item.key: item for item in label_requirement} | ||
label_keys = sorted(label_dict.keys()) | ||
label_keys_expected = sorted( | ||
[key for key in self.key_to_pref_map if self.key_to_pref_map[key] > 0] | ||
) | ||
np.testing.assert_equal(label_keys_expected, label_keys) | ||
|
||
def test_forward(self): | ||
module = self.forward_wrapper(self.module) | ||
label_requirement = self.module.label_requirement | ||
label_dict = {item.key: item for item in label_requirement} | ||
label_keys = sorted(label_dict.keys()) | ||
natoms = 5 | ||
nframes = 2 | ||
|
||
def fake_model(): | ||
model_predict = { | ||
data_key: fake_input( | ||
label_dict[data_key], natoms=natoms, nframes=nframes | ||
) | ||
for data_key in label_keys | ||
} | ||
if "atom_ener" in model_predict: | ||
model_predict["atom_energy"] = model_predict.pop("atom_ener") | ||
model_predict.update( | ||
{"mask_mag": np.ones([nframes, natoms, 1], dtype=np.bool_)} | ||
) | ||
return model_predict | ||
|
||
labels = { | ||
data_key: fake_input(label_dict[data_key], natoms=natoms, nframes=nframes) | ||
for data_key in label_keys | ||
} | ||
labels.update({"find_" + data_key: 1.0 for data_key in label_keys}) | ||
|
||
_, loss, more_loss = module( | ||
{}, | ||
fake_model, | ||
labels, | ||
natoms, | ||
1.0, | ||
) | ||
|
||
|
||
def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray: | ||
ndof = data_item.ndof | ||
atomic = data_item.atomic | ||
repeat = data_item.repeat | ||
rng = np.random.default_rng(seed=GLOBAL_SEED) | ||
dtype = data_item.dtype if data_item.dtype is not None else np.float64 | ||
if atomic: | ||
data = rng.random([nframes, natoms, ndof], dtype) | ||
else: | ||
data = rng.random([nframes, ndof], dtype) | ||
if repeat != 1: | ||
data = np.repeat(data, repeat).reshape([nframes, -1]) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from collections import ( | ||
OrderedDict, | ||
) | ||
|
||
from ....consistent.common import ( | ||
parameterize_func, | ||
) | ||
|
||
|
||
def LossParamEnergy( | ||
starter_learning_rate=1.0, | ||
pref_e=1.0, | ||
pref_f=1.0, | ||
pref_v=1.0, | ||
pref_ae=1.0, | ||
): | ||
key_to_pref_map = { | ||
"energy": pref_e, | ||
"force": pref_f, | ||
"virial": pref_v, | ||
"atom_ener": pref_ae, | ||
} | ||
input_dict = { | ||
"key_to_pref_map": key_to_pref_map, | ||
"starter_learning_rate": starter_learning_rate, | ||
"start_pref_e": pref_e, | ||
"limit_pref_e": pref_e / 2, | ||
"start_pref_f": pref_f, | ||
"limit_pref_f": pref_f / 2, | ||
"start_pref_v": pref_v, | ||
"limit_pref_v": pref_v / 2, | ||
"start_pref_ae": pref_ae, | ||
"limit_pref_ae": pref_ae / 2, | ||
} | ||
return input_dict | ||
|
||
|
||
LossParamEnergyList = parameterize_func( | ||
LossParamEnergy, | ||
OrderedDict( | ||
{ | ||
"pref_e": (1.0, 0.0), | ||
"pref_f": (1.0, 0.0), | ||
"pref_v": (1.0, 0.0), | ||
"pref_ae": (1.0, 0.0), | ||
} | ||
), | ||
) | ||
# to get name for the default function | ||
LossParamEnergy = LossParamEnergyList[0] | ||
|
||
|
||
def LossParamEnergySpin( | ||
starter_learning_rate=1.0, | ||
pref_e=1.0, | ||
pref_fr=1.0, | ||
pref_fm=1.0, | ||
pref_v=1.0, | ||
pref_ae=1.0, | ||
): | ||
key_to_pref_map = { | ||
"energy": pref_e, | ||
"force": pref_fr, | ||
"force_mag": pref_fm, | ||
"virial": pref_v, | ||
"atom_ener": pref_ae, | ||
} | ||
input_dict = { | ||
"key_to_pref_map": key_to_pref_map, | ||
"starter_learning_rate": starter_learning_rate, | ||
"start_pref_e": pref_e, | ||
"limit_pref_e": pref_e / 2, | ||
"start_pref_fr": pref_fr, | ||
"limit_pref_fr": pref_fr / 2, | ||
"start_pref_fm": pref_fm, | ||
"limit_pref_fm": pref_fm / 2, | ||
"start_pref_v": pref_v, | ||
"limit_pref_v": pref_v / 2, | ||
"start_pref_ae": pref_ae, | ||
"limit_pref_ae": pref_ae / 2, | ||
} | ||
return input_dict | ||
|
||
|
||
LossParamEnergySpinList = parameterize_func( | ||
LossParamEnergySpin, | ||
OrderedDict( | ||
{ | ||
"pref_e": (1.0, 0.0), | ||
"pref_fr": (1.0, 0.0), | ||
"pref_fm": (1.0, 0.0), | ||
"pref_v": (1.0, 0.0), | ||
"pref_ae": (1.0, 0.0), | ||
} | ||
), | ||
) | ||
# to get name for the default function | ||
LossParamEnergySpin = LossParamEnergySpinList[0] | ||
|
||
|
||
def LossParamDos( | ||
starter_learning_rate=1.0, | ||
pref_dos=1.0, | ||
pref_ados=1.0, | ||
): | ||
key_to_pref_map = { | ||
"dos": pref_dos, | ||
"atom_dos": pref_ados, | ||
} | ||
input_dict = { | ||
"key_to_pref_map": key_to_pref_map, | ||
"starter_learning_rate": starter_learning_rate, | ||
"numb_dos": 2, | ||
"start_pref_dos": pref_dos, | ||
"limit_pref_dos": pref_dos / 2, | ||
"start_pref_ados": pref_ados, | ||
"limit_pref_ados": pref_ados / 2, | ||
"start_pref_cdf": 0.0, | ||
"limit_pref_cdf": 0.0, | ||
"start_pref_acdf": 0.0, | ||
"limit_pref_acdf": 0.0, | ||
} | ||
return input_dict | ||
|
||
|
||
LossParamDosList = parameterize_func( | ||
LossParamDos, | ||
OrderedDict( | ||
{ | ||
"pref_dos": (1.0,), | ||
"pref_ados": (1.0, 0.0), | ||
} | ||
), | ||
) + parameterize_func( | ||
LossParamDos, | ||
OrderedDict( | ||
{ | ||
"pref_dos": (0.0,), | ||
"pref_ados": (1.0,), | ||
} | ||
), | ||
) | ||
|
||
# to get name for the default function | ||
LossParamDos = LossParamDosList[0] | ||
|
||
|
||
def LossParamTensor( | ||
pref=1.0, | ||
pref_atomic=1.0, | ||
): | ||
tensor_name = "test_tensor" | ||
key_to_pref_map = { | ||
tensor_name: pref, | ||
f"atomic_{tensor_name}": pref_atomic, | ||
} | ||
input_dict = { | ||
"key_to_pref_map": key_to_pref_map, | ||
"tensor_name": tensor_name, | ||
"tensor_size": 2, | ||
"label_name": tensor_name, | ||
"pref": pref, | ||
"pref_atomic": pref_atomic, | ||
} | ||
return input_dict | ||
|
||
|
||
LossParamTensorList = parameterize_func( | ||
LossParamTensor, | ||
OrderedDict( | ||
{ | ||
"pref": (1.0,), | ||
"pref_atomic": (1.0, 0.0), | ||
} | ||
), | ||
) + parameterize_func( | ||
LossParamTensor, | ||
OrderedDict( | ||
{ | ||
"pref": (0.0,), | ||
"pref_atomic": (1.0,), | ||
} | ||
), | ||
) | ||
# to get name for the default function | ||
LossParamTensor = LossParamTensorList[0] | ||
|
||
|
||
def LossParamProperty(): | ||
key_to_pref_map = { | ||
"property": 1.0, | ||
} | ||
input_dict = { | ||
"key_to_pref_map": key_to_pref_map, | ||
"task_dim": 2, | ||
} | ||
return input_dict | ||
|
||
|
||
LossParamPropertyList = [LossParamProperty] | ||
# to get name for the default function | ||
LossParamProperty = LossParamPropertyList[0] |
Oops, something went wrong.