Skip to content

Commit

Permalink
Merge pull request #671 from ACEsuit/develop
Browse files Browse the repository at this point in the history
Add default argparser to main
  • Loading branch information
ilyes319 authored Nov 11, 2024
2 parents 4081abd + abd7e5e commit f1e671d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 9 deletions.
4 changes: 3 additions & 1 deletion mace/cli/active_learning_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def parse_args():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"model_path",
type=str,
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--configs_pt",
help="path to XYZ configurations for the pretraining",
Expand Down
5 changes: 4 additions & 1 deletion mace/cli/plot_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]:


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot mace training statistics")
parser = argparse.ArgumentParser(
description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--path", help="path to results file or directory", required=True
)
Expand Down
9 changes: 7 additions & 2 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
Expand All @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
help="config file to agregate options",
)
except ImportError:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Name and seed
parser.add_argument("--name", help="experiment name", required=True)
Expand Down Expand Up @@ -704,7 +707,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
18 changes: 16 additions & 2 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def load_foundations_elements(
)
if (
model.interactions[i].__class__.__name__
== "RealAgnosticResidualInteractionBlock"
in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
Expand All @@ -101,7 +101,21 @@ def load_foundations_elements(
.clone()
/ (num_species_foundations / num_species) ** 0.5
)

if (
model.interactions[i].__class__.__name__
in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
)
.weight
.clone()
)
)
# Transferring products
for i in range(2): # Assuming 2 products modules
max_range = max_L + 1 if i == 0 else 1
Expand Down

0 comments on commit f1e671d

Please sign in to comment.