diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e8319ac7..1c0898b7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -335,8 +335,21 @@ def run(args: argparse.Namespace) -> None: ) head_config_pt.collections = collections head_configs.append(head_config_pt) + + ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) + if ratio_pt_ft < 0.1: + logging.warning( + f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " + f"increasing the number of configurations in the pt_train_file by a factor of {int(0.1 / ratio_pt_ft)}" + ) + for head_config in head_configs: + if head_config.head_name == "pt_head": + continue + head_config.collections.train += head_config.collections.train * int( + 0.1 / ratio_pt_ft + ) logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" ) # Atomic number table