From 5063f5e4b961824f3812412f809e901bbaeec995 Mon Sep 17 00:00:00 2001 From: Lars Leon Schaaf <41149633+LarsSchaaf@users.noreply.github.com> Date: Thu, 7 Nov 2024 19:47:25 +0000 Subject: [PATCH] Cluster Force: flatten indices --- mace/modules/loss.py | 12 +++++++++--- mace/tools/scatter.py | 28 ++++++++++++++++++++++++++++ mace/tools/train.py | 7 ++++--- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index bd541434..1c47ae85 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -7,7 +7,7 @@ import torch from mace.tools import TensorDict -from mace.tools.scatter import scatter_sum +from mace.tools.scatter import scatter_sum, compute_effective_index from mace.tools.torch_geometric import Batch @@ -31,12 +31,18 @@ def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Te def weighted_mean_square_error_force_cluster( ref: Batch, pred: TensorDict ) -> torch.Tensor: + effective_inicies, _ = compute_effective_index([ref.batch, ref.cluster]) cluster_forces_ref = scatter_sum( - ref["forces"], torch.unique(ref.cluster, return_inverse=True)[1], dim=0 + ref["forces"], + effective_inicies, + dim=0, ) cluster_forces_pred = scatter_sum( - pred["forces"], torch.unique(ref.cluster, return_inverse=True)[1], dim=0 + pred["forces"], + effective_inicies, + dim=0, ) + return torch.mean(torch.square(cluster_forces_ref - cluster_forces_pred)) diff --git a/mace/tools/scatter.py b/mace/tools/scatter.py index 7e1139a9..d17cb7cb 100644 --- a/mace/tools/scatter.py +++ b/mace/tools/scatter.py @@ -10,6 +10,7 @@ from typing import Optional import torch +from typing import List, Tuple def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): @@ -110,3 +111,30 @@ def scatter_mean( else: out.div_(count, rounding_mode="floor") return out + + +def compute_effective_index(indices: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes an effective index from multiple index tensors. 
Useful for multi index scatter operations. + + Args: + indices (List[torch.Tensor]): List of index tensors, each of shape (N,). + + Returns: + effective_index (torch.Tensor): Tensor of shape (N,), where each element + is a unique integer representing the combination of indices. + unique_combinations (torch.Tensor): Tensor containing unique combinations + of indices, shape (num_unique_combinations, num_indices). + """ + # Stack indices to shape (num_indices, N) + indices_stack = torch.stack(indices, dim=0) # Shape: (num_indices, N) + + # Transpose to get combinations per element + index_combinations = indices_stack.t() # Shape: (N, num_indices) + + # Find unique combinations and get inverse indices + unique_combinations, inverse_indices = torch.unique( + index_combinations, dim=0, return_inverse=True + ) + + return inverse_indices, unique_combinations diff --git a/mace/tools/train.py b/mace/tools/train.py index dfdb0d9e..0f0f74bd 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -22,7 +22,7 @@ from . import torch_geometric from .checkpoint import CheckpointHandler, CheckpointState -from .scatter import scatter_sum +from .scatter import scatter_sum, compute_effective_index from .torch_tools import to_numpy from .utils import ( MetricsLogger, @@ -491,14 +491,15 @@ def update(self, batch, output): # pylint: disable=arguments-differ self.delta_fs.append(batch.forces - output["forces"]) if output.get("forces") is not None and batch.cluster is not None: self.clusterFs_computed += 1.0 + effective_inicies, _ = compute_effective_index([batch.batch, batch.cluster]) cluster_forces_ref = scatter_sum( batch["forces"], - torch.unique(batch.cluster, return_inverse=True)[1], + effective_inicies, dim=0, ) cluster_forces_pred = scatter_sum( output["forces"], - torch.unique(batch.cluster, return_inverse=True)[1], + effective_inicies, dim=0, ) self.delta_cluster_forces.append(cluster_forces_ref - cluster_forces_pred)