From f1be21e1fbd742ab53448b3851b5e1d50abfc04f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 16 May 2024 23:04:33 +0100 Subject: [PATCH] improving the pt selection and fix E0 bug --- mace/cli/fine_tuning_select.py | 23 +++++++++++++++++++++-- mace/cli/run_train.py | 2 +- mace/tools/arg_parser.py | 6 ++++++ mace/tools/scripts_utils.py | 9 ++++++++- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index a9349a6f..19e44d7d 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -92,6 +92,7 @@ def parse_args() -> argparse.Namespace: type=float, default=1.0, ) + parser.add_argument("--seed", help="random seed", type=int, default=42) return parser.parse_args() @@ -197,6 +198,8 @@ def assemble_descriptors(self) -> np.ndarray: def select_samples( args: argparse.Namespace, ) -> None: + np.random.seed(args.seed) + torch.manual_seed(args.seed) if args.model in ["small", "medium", "large"]: calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) else: @@ -217,18 +220,34 @@ def select_samples( atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] - atoms_list_pt = [ + atoms_list_pt_filtered = [ x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") - atoms_list_pt = [ + atoms_list_pt_filtered = [ x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] + if len(atoms_list_pt_filtered) <= args.num_samples: + logging.info( + "Number of configurations after filtering is less than the number of samples, " + "selecting random configurations, for the rest." + ) + atoms_list_pt_minus_filtered = [ + x for x in atoms_list_pt if x not in atoms_list_pt_filtered + ] + atoms_list_pt_random = np.random.choice( + atoms_list_pt_minus_filtered, + args.num_samples - len(atoms_list_pt_filtered), + ).tolist() + atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random + else: + atoms_list_pt = atoms_list_pt_filtered + else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") if args.descriptors is not None: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 6497f7cf..462966d7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -238,7 +238,7 @@ def main() -> None: args_samples = { "configs_pt": dataset_mp, "configs_ft": args.train_file, - "num_samples": 1000, + "num_samples": args.num_samples_pt, "seed": args.seed, "model": args.foundation_model, "head_pt": "pbe_mp", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index b05816dd..abb550e1 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -350,6 +350,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=float, default=1.0, ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=1000, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index e03d754b..9254abb3 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -153,7 +153,14 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: atomic_energies_dict = json.load(open(E0s, "r")) else: try: - atomic_energies_dict = ast.literal_eval(E0s) + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = {"Default": atomic_energies_eval} + else: + atomic_energies_dict = atomic_energies_eval assert isinstance(atomic_energies_dict, dict) except Exception as e: raise RuntimeError(