Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to fix sign of scalar diffs and parity quantities in error analysis #336

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions wfl/fit/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def calc(inputs, calc_property_prefix, ref_property_prefix,
Returns
-------
errors: dict of RMSE and MAE for each category and property
diffs: dict with list of differences for each category and property
diffs: dict with list of differences for each category and property (signed for scalar
properties, norms for vectors)
parity: dict with "ref" and "calc" keys, each containing list of property values for
each category and property, for parity plots
"""
Expand Down Expand Up @@ -188,13 +189,13 @@ def _reshape_normalize(quant, prop, atoms, per_atom):

if len(diff.shape) != 2:
raise RuntimeError(f"Should never have diff.shape={diff.shape} with dim != 2 (prop {prop + atom_split_index_label})")
# compute norm along vector components
diff = np.linalg.norm(diff, axis=1)
if not per_component:
if diff.shape[1] > 1:
# compute norm along vector components
diff = np.linalg.norm(diff, axis=1)
if not per_component and selected_ref_quant.shape[1] > 1:
selected_ref_quant = np.linalg.norm(selected_ref_quant, axis=1)
selected_calc_quant = np.linalg.norm(selected_calc_quant, axis=1)


_dict_add([all_diffs, all_weights, all_parity["ref"], all_parity["calc"]],
[diff, _promote(weight, diff), selected_ref_quant, selected_calc_quant],
at_category, prop + atom_split_index_label)
Expand Down Expand Up @@ -380,16 +381,15 @@ def select_units(prop, plt_type, units_dict=None):
"energy/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)},
"forces": {"parity": ("eV/Å", 1.0), "error": ("meV/Å", 1.0e3)},
"virial": {"parity": ("eV", 1.0), "error": ("meV", 1.0e3)},
"virial/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)}
"virial/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)},
"stress": {"parity": ("GPa", 1.0), "error": ("MPa", 1.0e3)},
}
if units_dict is None:
units_dict = {}
use_units_dict.update(units_dict)

if "virial" in prop:
prop = re.sub(r"/comp\b", "", prop)
prop = re.sub(r"/comp\b", "", prop)
if "forces" in prop:
prop = re.sub(r"/comp\b", "", prop)
prop = re.sub(r"/Z_\d+\b", "", prop)

if "energy" in prop:
Expand Down
Loading