Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default argparser to main #671

Merged
merged 5 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mace/cli/active_learning_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def parse_args():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"model_path",
type=str,
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--configs_pt",
help="path to XYZ configurations for the pretraining",
Expand Down
5 changes: 4 additions & 1 deletion mace/cli/plot_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]:


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot mace training statistics")
parser = argparse.ArgumentParser(
description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--path", help="path to results file or directory", required=True
)
Expand Down
9 changes: 7 additions & 2 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
Expand All @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
help="config file to agregate options",
)
except ImportError:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Name and seed
parser.add_argument("--name", help="experiment name", required=True)
Expand Down Expand Up @@ -704,7 +707,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
18 changes: 16 additions & 2 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def load_foundations_elements(
)
if (
model.interactions[i].__class__.__name__
== "RealAgnosticResidualInteractionBlock"
in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
Expand All @@ -101,7 +101,21 @@ def load_foundations_elements(
.clone()
/ (num_species_foundations / num_species) ** 0.5
)

if (
model.interactions[i].__class__.__name__
in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
)
.weight
.clone()
)
)
# Transferring products
for i in range(2): # Assuming 2 products modules
max_range = max_L + 1 if i == 0 else 1
Expand Down
Loading