Skip to content

Commit

Permalink
Improve the pt selection and fix the E0 bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed May 16, 2024
1 parent 61781f4 commit f1be21e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
23 changes: 21 additions & 2 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def parse_args() -> argparse.Namespace:
type=float,
default=1.0,
)
parser.add_argument("--seed", help="random seed", type=int, default=42)
return parser.parse_args()


Expand Down Expand Up @@ -197,6 +198,8 @@ def assemble_descriptors(self) -> np.ndarray:
def select_samples(
args: argparse.Namespace,
) -> None:
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
Expand All @@ -217,18 +220,34 @@ def select_samples(
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
atoms_list_pt = [
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
atoms_list_pt = [
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
if len(atoms_list_pt_filtered) <= args.num_samples:
logging.info(
"Number of configurations after filtering is less than the number of samples, "
"selecting random configurations, for the rest."
)
atoms_list_pt_minus_filtered = [
x for x in atoms_list_pt if x not in atoms_list_pt_filtered
]
atoms_list_pt_random = np.random.choice(
atoms_list_pt_minus_filtered,
args.num_samples - len(atoms_list_pt_filtered),
).tolist()
atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random
else:
atoms_list_pt = atoms_list_pt_filtered

else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
if args.descriptors is not None:
Expand Down
2 changes: 1 addition & 1 deletion mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def main() -> None:
args_samples = {
"configs_pt": dataset_mp,
"configs_ft": args.train_file,
"num_samples": 1000,
"num_samples": args.num_samples_pt,
"seed": args.seed,
"model": args.foundation_model,
"head_pt": "pbe_mp",
Expand Down
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=float,
default=1.0,
)
parser.add_argument(
"--num_samples_pt",
help="Number of samples in the pretrained head",
type=int,
default=1000,
)
parser.add_argument(
"--keep_isolated_atoms",
help="Keep isolated atoms in the dataset, useful for transfer learning",
Expand Down
9 changes: 8 additions & 1 deletion mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,14 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict:
atomic_energies_dict = json.load(open(E0s, "r"))
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
atomic_energies_eval = ast.literal_eval(E0s)
if not all(
isinstance(value, dict)
for value in atomic_energies_eval.values()
):
atomic_energies_dict = {"Default": atomic_energies_eval}
else:
atomic_energies_dict = atomic_energies_eval
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
Expand Down

0 comments on commit f1be21e

Please sign in to comment.