diff --git a/tests/test_error.py b/tests/test_error.py index 5de66250..0de974be 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -5,6 +5,7 @@ import pytest from ase.atoms import Atoms from ase.calculators.lj import LennardJones +from ase.stress import voigt_6_to_full_3x3_stress from pytest import approx from pprint import pprint @@ -103,6 +104,17 @@ def test_err_from_calc(ref_atoms): assert ref_err_dict['virial/atom/comp']['_ALL_']["count"] == 10 * 6 +def test_err_stress_shape(ref_atoms): + ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_') + ref_err_dict, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_') + + for at in ref_atoms_calc: + at.info["REF_stress"] = voigt_6_to_full_3x3_stress(at.info["REF_stress"]) + ref_err_dict_shape, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_') + + assert ref_err_dict == ref_err_dict_shape + + def test_error_properties(ref_atoms): ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_') # both energy and per atom diff --git a/wfl/fit/error.py b/wfl/fit/error.py index e2986d64..14d97285 100755 --- a/wfl/fit/error.py +++ b/wfl/fit/error.py @@ -85,8 +85,6 @@ def _reshape_normalize(quant, prop, atoms, per_atom): quant: 2-d array containing reshaped quantity, with leading dimension 1 for per-config or len(atoms) for per-atom """ - # convert scalars or lists into arrays - quant = np.asarray(quant) # fix shape of stress/virial if prop.startswith("stress") or prop.startswith("virial"): @@ -156,9 +154,9 @@ def _reshape_normalize(quant, prop, atoms, per_atom): raise ValueError("/atom only possible in config_properties") data = at.arrays - # grab data - ref_quant = data.get(ref_property_prefix + prop_use) - calc_quant = data.get(calc_property_prefix + prop_use) + # grab data, make a copy so normalization doesn't affect original + ref_quant = np.asarray(data.get(ref_property_prefix + prop_use)).copy() + calc_quant = np.asarray(data.get(calc_property_prefix + prop_use)).copy() if ref_quant is None or calc_quant is None: # warn if data is missing by reporting summary at the very end if prop not in missed_prop_counter: