From 7cd71d82611178901819cea7d87c5711cb3e1284 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 1 May 2024 16:48:36 -0400 Subject: [PATCH] Fix bug that overwrote REF_* keys when those were the explicitly specified keys for the training reference quantities --- mace/data/utils.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index c55ad86b..4dd96287 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -212,45 +212,52 @@ def load_from_xyz( ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") + energy_from_calc = False + forces_from_calc = False + stress_from_calc = False + # Perform initial checks and log warnings if energy_key == "energy": logging.info( - "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to 'REF_energy'" + "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to '_REF_energy'" ) - energy_key = "REF_energy" + energy_from_calc = True + energy_key = "_REF_energy" if forces_key == "forces": logging.info( - "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'" + "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to '_REF_forces'" ) - forces_key = "REF_forces" + forces_from_calc = True + forces_key = "_REF_forces" if stress_key == "stress": logging.info( - "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'" + "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to '_REF_stress'" ) - stress_key = "REF_stress" + stress_from_calc = True + stress_key = "_REF_stress" for atoms in atoms_list: - if energy_key == "REF_energy": + if energy_from_calc: try: - atoms.info["REF_energy"] = atoms.get_potential_energy() + atoms.info["_REF_energy"] = atoms.get_potential_energy() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract energy: {e}") - atoms.info["REF_energy"] = None + atoms.info["_REF_energy"] = None - if forces_key == "REF_forces": + if forces_from_calc: try: - atoms.info["REF_forces"] = atoms.get_forces() + atoms.info["_REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract forces: {e}") - atoms.info["REF_forces"] = None + atoms.info["_REF_forces"] = None - if stress_key == "REF_stress": + if stress_from_calc: try: - atoms.info["REF_stress"] = atoms.get_stress() + atoms.info["_REF_stress"] = atoms.get_stress() except Exception as e: # pylint: disable=W0703 - atoms.info["REF_stress"] = None + atoms.info["_REF_stress"] = None if not isinstance(atoms_list, list): atoms_list = [atoms_list]