Skip to content

Commit

Permalink
use .4g as format gives better float representation for numbers in tr…
Browse files Browse the repository at this point in the history
…aining, especially loss
  • Loading branch information
Alin Marin Elena authored and alinelena committed Nov 14, 2024
1 parent bd41231 commit 98347f3
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def valid_err_log(
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A"
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_E_per_atom={error_e:.4g} meV, RMSE_F={error_f:.4g} meV / A"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand All @@ -70,7 +70,7 @@ def valid_err_log(
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_E_per_atom={error_e:.4g} meV, RMSE_F={error_f:.4g} meV / A, RMSE_stress={error_stress:.4g} meV / A^3",
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand All @@ -80,7 +80,7 @@ def valid_err_log(
error_f = eval_metrics["rmse_f"] * 1e3
error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_E_per_atom={error_e:.4g} meV, RMSE_F={error_f:.4g} meV / A, RMSE_virials_per_atom={error_virials:.4g} meV",
)
elif (
log_errors == "PerAtomMAEstressvirials"
Expand All @@ -90,7 +90,7 @@ def valid_err_log(
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3"
f"{inintial_phrase}: loss={valid_loss:.4g}, MAE_E_per_atom={error_e:.4g} meV, MAE_F={error_f:.4g} meV / A, MAE_stress={error_stress:.4g} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
Expand All @@ -100,37 +100,37 @@ def valid_err_log(
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
logging.info(
f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV"
f"{inintial_phrase}: loss={valid_loss:.4g}, MAE_E_per_atom={error_e:.4g} meV, MAE_F={error_f:.4g} meV / A, MAE_virials={error_virials:.4g} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_E={error_e:.4g} meV, RMSE_F={error_f:.4g} meV / A",
)
elif log_errors == "PerAtomMAE":
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, MAE_E_per_atom={error_e:.4g} meV, MAE_F={error_f:.4g} meV / A",
)
elif log_errors == "TotalMAE":
error_e = eval_metrics["mae_e"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, MAE_E={error_e:.4g} meV, MAE_F={error_f:.4g} meV / A",
)
elif log_errors == "DipoleRMSE":
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_MU_per_atom={error_mu:.4g} mDebye",
)
elif log_errors == "EnergyDipoleRMSE":
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:.4g}, RMSE_E_per_atom={error_e:.4g} meV, RMSE_F={error_f:.4g} meV / A, RMSE_Mu_per_atom={error_mu:.4g} mDebye",
)


Expand Down

0 comments on commit 98347f3

Please sign in to comment.