diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index a26be698..9cf4f4a8 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -14,7 +14,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--config", help="path to XYZ configurations", required=True) parser.add_argument( "--config_index", help="index of configuration", type=int, default=-1 diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_dev.py new file mode 100644 index 00000000..9dd8c61d --- /dev/null +++ b/mace/cli/convert_dev.py @@ -0,0 +1,31 @@ +from argparse import ArgumentParser + +import torch + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model.to(args.target_device) + torch.save(model, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 1917ab8e..507a2cd0 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -7,7 +7,9 @@ def parse_args(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "model_path", type=str, diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index 86470f8d..d00c54c6 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -16,7 +16,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--configs", help="path to XYZ configurations", required=True) parser.add_argument("--model", help="path to model", required=True) parser.add_argument("--output", help="output path", required=True) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 2fa5f644..94baf0dd 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -20,7 +20,9 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--configs_pt", help="path to XYZ configurations for the pretraining", diff --git a/mace/cli/plot_train.py b/mace/cli/plot_train.py index c249d76a..a1c424df 100644 --- a/mace/cli/plot_train.py +++ b/mace/cli/plot_train.py @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Plot mace training statistics") + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--path", help="path to results file or directory", required=True ) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 654cc358..9a754ffc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -363,8 +363,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict[head_config.head_name] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs @@ -382,8 +388,14 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") atomic_energies_dict["pt_head"] = { - z: model_foundation.atomic_energies_fn.atomic_energies[ + z: foundation_atomic_energies[ z_table_foundation.z_to_index(z) ].item() for z in z_table.zs @@ -663,7 +675,6 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): - logging.info("test_name", test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..e48e0b23 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -15,6 +15,8 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ResidualElementDependentInteractionBlock, @@ -56,6 +58,8 @@ "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 34539b0b..0db3b02e 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -664,6 +664,181 @@ def forward( ) # [n_nodes, channels, (lmax + 1)**2] +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + # Reshape + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index f64fe41b..9db0ccd5 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser = configargparse.ArgumentParser( config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add( "--config", @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: help="config file to agregate options", ) except ImportError: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) # Name and seed parser.add_argument("--name", help="experiment name", required=True) @@ -153,6 +156,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "RealAgnosticResidualInteractionBlock", "RealAgnosticAttResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -163,6 +168,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=[ "RealAgnosticResidualInteractionBlock", "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", ], ) parser.add_argument( @@ -706,7 +713,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: def build_preprocess_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( "--train_file", help="Training set h5 file", diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index b2777d49..9f9fe4c2 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -617,8 +617,10 @@ def custom_key(key): def dict_to_array(input_data, heads): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) + return np.array([[input_data[head]] for head in heads]) unique_keys = set() for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) @@ -630,7 +632,7 @@ def dict_to_array(input_data, heads): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) result_array[head_index][key_index] = value - return np.squeeze(result_array) + return result_array class LRScheduler: diff --git a/setup.cfg b/setup.cfg index 6751b12d..139f914e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main + mace_convert_dev = mace.cli.convert_dev:main [options.extras_require] wandb = wandb @@ -54,4 +55,4 @@ dev = pytest pytest-benchmark pylint -schedulefree = schedulefree \ No newline at end of file +schedulefree = schedulefree diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 80d7eaa7..153acfdf 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -1,3 +1,4 @@ +import json import os import subprocess import sys @@ -7,7 +8,6 @@ import numpy as np import pytest from ase.atoms import Atoms -import json from mace.calculators.mace import MACECalculator @@ -614,7 +614,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): if i in (0, 1): continue # skip isolated atoms, as energies specified by json files below - elif i % 2 == 0: + if i % 2 == 0: c.info["head"] = "DFT" fitting_configs_dft.append(c) else: @@ -625,9 +625,9 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): # write E0s to json files E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w") as f: + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w") as f: + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: json.dump(E0s, f) heads = {