diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 725c5d21..a4b2aebe 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -357,7 +357,10 @@ def run(args: argparse.Namespace) -> None: if args.loss in ("stress", "virials", "huber", "universal"): compute_virials = True args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" + if "MAE" in args.error_table: + args.error_table = "PerAtomMAEstressvirials" + else: + args.error_table = "PerAtomRMSEstressvirials" output_args = { "energy": compute_energy, @@ -821,7 +824,9 @@ def run(args: argparse.Namespace) -> None: ), } if swa_eval: - torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model")) + torch.save( + model, Path(args.model_dir) / (args.name + "_stagetwo.model") + ) try: path_complied = Path(args.model_dir) / ( args.name + "_stagetwo_compiled.model" diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 37d0ce8d..38034335 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -80,6 +80,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "PerAtomRMSE", "TotalRMSE", "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", "PerAtomMAE", "TotalMAE", "DipoleRMSE", @@ -388,7 +389,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--forces_weight", help="weight of forces loss", type=float, default=100.0 ) parser.add_argument( - "--swa_forces_weight","--stage_two_forces_weight", + "--swa_forces_weight", + "--stage_two_forces_weight", help="weight of forces loss after starting Stage Two (previously called swa)", type=float, default=100.0, @@ -398,7 +400,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--energy_weight", help="weight of energy loss", type=float, default=1.0 ) parser.add_argument( - "--swa_energy_weight","--stage_two_energy_weight", + "--swa_energy_weight", + "--stage_two_energy_weight", help="weight of energy loss after starting Stage Two (previously called swa)", type=float, default=1000.0, @@ -408,7 +411,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--virials_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_virials_weight", "--stage_two_virials_weight", + "--swa_virials_weight", + "--stage_two_virials_weight", help="weight of virials loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -418,7 +422,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_stress_weight", "--stage_two_stress_weight", + "--swa_stress_weight", + "--stage_two_stress_weight", help="weight of stress loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -428,7 +433,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 ) parser.add_argument( - "--swa_dipole_weight","--stage_two_dipole_weight", + "--swa_dipole_weight", + "--stage_two_dipole_weight", help="weight of dipoles after starting Stage Two (previously called swa)", type=float, default=1.0, @@ -467,7 +473,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--lr", help="Learning rate of optimizer", type=float, default=0.01 ) parser.add_argument( - "--swa_lr", "--stage_two_lr", help="Learning rate of optimizer in Stage Two (previously called swa)", type=float, default=1e-3, dest="swa_lr" + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", ) parser.add_argument( "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 @@ -494,14 +505,16 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=0.9993, ) parser.add_argument( - "--swa", "--stage_two", + "--swa", + "--stage_two", help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", action="store_true", default=False, dest="swa", ) parser.add_argument( - "--start_swa","--start_stage_two", + "--start_swa", + "--start_stage_two", help="Number of epochs before changing to Stage Two loss weights", type=int, default=None, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index cc7b3929..106cb9b0 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -304,11 +304,18 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: f"Could not compute average E0s if no training xyz given, error {e} occured" ) from e else: - try: - atomic_energies_dict = ast.literal_eval(E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) + else: + try: + atomic_energies_dict = ast.literal_eval(E0s) + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e else: raise RuntimeError( "E0s not found in training file and not specified in command line" @@ -454,6 +461,14 @@ def create_error_table( "relative F RMSE %", "RMSE Stress (Virials) / meV / A (A^3)", ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] elif table_type == "TotalMAE": table.field_names = [ "config_type", @@ -558,6 +573,32 @@ def create_error_table( f"{metrics['rmse_virials'] * 1000:.1f}", ] ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:.1f}", + f"{metrics['mae_f'] * 1000:.1f}", + f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_stress'] * 1000:.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:.1f}", + f"{metrics['mae_f'] * 1000:.1f}", + f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_virials'] * 1000:.1f}", + ] + ) elif table_type == "TotalMAE": table.add_row( [ diff --git a/mace/tools/train.py b/mace/tools/train.py index 575fb02d..20256cec 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -71,6 +71,26 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): logging.info( f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_stress_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_stress = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_virials_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_virials = eval_metrics["mae_virials"] * 1e3 + logging.info( + f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" + ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3