Skip to content

Commit

Permalink
Merge branch 'develop' into refactor-data
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 authored Nov 6, 2024
2 parents 7a19ed6 + c1184eb commit b8ef3ab
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 17 deletions.
4 changes: 3 additions & 1 deletion mace/cli/active_learning_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", help="path to XYZ configurations", required=True)
parser.add_argument(
"--config_index", help="index of configuration", type=int, default=-1
Expand Down
31 changes: 31 additions & 0 deletions mace/cli/convert_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from argparse import ArgumentParser

import torch


def main():
parser = ArgumentParser()
parser.add_argument(
"--target_device",
"-t",
help="device to convert to, usually 'cpu' or 'cuda'",
default="cpu",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model.to(args.target_device)
torch.save(model, args.output_file)


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def parse_args():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"model_path",
type=str,
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--configs", help="path to XYZ configurations", required=True)
parser.add_argument("--model", help="path to model", required=True)
parser.add_argument("--output", help="output path", required=True)
Expand Down
4 changes: 3 additions & 1 deletion mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--configs_pt",
help="path to XYZ configurations for the pretraining",
Expand Down
5 changes: 4 additions & 1 deletion mace/cli/plot_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def parse_training_results(path: str) -> List[dict]:


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot mace training statistics")
parser = argparse.ArgumentParser(
description="Plot mace training statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--path", help="path to results file or directory", required=True
)
Expand Down
17 changes: 14 additions & 3 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict[head_config.head_name] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
Expand All @@ -382,8 +388,14 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head as foundation E0s.")
atomic_energies_dict["pt_head"] = {
z: model_foundation.atomic_energies_fn.atomic_energies[
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
Expand Down Expand Up @@ -663,7 +675,6 @@ def run(args: argparse.Namespace) -> None:
folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name
)
for test_name, test_set in test_sets.items():
logging.info("test_name", test_name)
test_sampler = None
if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler(
Expand Down
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ResidualElementDependentInteractionBlock,
Expand Down Expand Up @@ -56,6 +58,8 @@
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}

scaling_classes: Dict[str, Callable] = {
Expand Down
175 changes: 175 additions & 0 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,181 @@ def forward(
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
# First linear
self.linear_up = o3.Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = o3.TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)

# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)

# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.irreps_out, self.node_attrs_irreps, self.irreps_out
)
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / (density + 1)
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
# First linear
self.linear_up = o3.Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = o3.TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)

# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)

# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)

# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / (density + 1)
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
Expand Down
13 changes: 11 additions & 2 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
Expand All @@ -23,7 +24,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
help="config file to agregate options",
)
except ImportError:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Name and seed
parser.add_argument("--name", help="experiment name", required=True)
Expand Down Expand Up @@ -153,6 +156,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"RealAgnosticResidualInteractionBlock",
"RealAgnosticAttResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
Expand All @@ -163,6 +168,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
choices=[
"RealAgnosticResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
Expand Down Expand Up @@ -706,7 +713,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
6 changes: 4 additions & 2 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,10 @@ def custom_key(key):


def dict_to_array(input_data, heads):
if all(isinstance(value, np.ndarray) for value in input_data.values()):
return np.array([input_data[head] for head in heads])
if not all(isinstance(value, dict) for value in input_data.values()):
return np.array(list(input_data.values()))
return np.array([[input_data[head]] for head in heads])
unique_keys = set()
for inner_dict in input_data.values():
unique_keys.update(inner_dict.keys())
Expand All @@ -630,7 +632,7 @@ def dict_to_array(input_data, heads):
key_index = sorted_keys.index(int(key))
head_index = heads.index(head_name)
result_array[head_index][key_index] = value
return np.squeeze(result_array)
return result_array


class LRScheduler:
Expand Down
Loading

0 comments on commit b8ef3ab

Please sign in to comment.