Skip to content

Commit

Permalink
Fix bug that overwrote REF_* keys when those were the explicitly
Browse files Browse the repository at this point in the history
specified keys for the training reference quantities
  • Loading branch information
bernstei committed May 1, 2024
1 parent 7697261 commit 7cd71d8
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,45 +212,52 @@ def load_from_xyz(
) -> Tuple[Dict[int, float], Configurations]:
atoms_list = ase.io.read(file_path, index=":")

energy_from_calc = False
forces_from_calc = False
stress_from_calc = False

# Perform initial checks and log warnings
if energy_key == "energy":
logging.info(
"Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to 'REF_energy'"
"Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to '_REF_energy'"
)
energy_key = "REF_energy"
energy_from_calc = True
energy_key = "_REF_energy"

if forces_key == "forces":
logging.info(
"Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'"
"Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to '_REF_forces'"
)
forces_key = "REF_forces"
forces_from_calc = True
forces_key = "_REF_forces"

if stress_key == "stress":
logging.info(
"Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'"
"Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to '_REF_stress'"
)
stress_key = "REF_stress"
stress_from_calc = True
stress_key = "_REF_stress"

for atoms in atoms_list:
if energy_key == "REF_energy":
if energy_from_calc:
try:
atoms.info["REF_energy"] = atoms.get_potential_energy()
atoms.info["_REF_energy"] = atoms.get_potential_energy()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract energy: {e}")
atoms.info["REF_energy"] = None
atoms.info["_REF_energy"] = None

if forces_key == "REF_forces":
if forces_from_calc:
try:
atoms.info["REF_forces"] = atoms.get_forces()
atoms.info["_REF_forces"] = atoms.get_forces()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract forces: {e}")
atoms.info["REF_forces"] = None
atoms.info["_REF_forces"] = None

if stress_key == "REF_stress":
if stress_from_calc:
try:
atoms.info["REF_stress"] = atoms.get_stress()
atoms.info["_REF_stress"] = atoms.get_stress()
except Exception as e: # pylint: disable=W0703
atoms.info["REF_stress"] = None
atoms.info["_REF_stress"] = None

if not isinstance(atoms_list, list):
atoms_list = [atoms_list]
Expand Down

0 comments on commit 7cd71d8

Please sign in to comment.