Skip to content

Commit

Permalink
Cluster Force: flatten indicies
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsSchaaf committed Nov 7, 2024
1 parent 25709d9 commit 5063f5e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
12 changes: 9 additions & 3 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from mace.tools import TensorDict
from mace.tools.scatter import scatter_sum
from mace.tools.scatter import scatter_sum, compute_effective_index
from mace.tools.torch_geometric import Batch


Expand All @@ -31,12 +31,18 @@ def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Te
def weighted_mean_square_error_force_cluster(
ref: Batch, pred: TensorDict
) -> torch.Tensor:
effective_inicies, _ = compute_effective_index([ref.batch, ref.cluster])
cluster_forces_ref = scatter_sum(
ref["forces"], torch.unique(ref.cluster, return_inverse=True)[1], dim=0
ref["forces"],
effective_inicies,
dim=0,
)
cluster_forces_pred = scatter_sum(
pred["forces"], torch.unique(ref.cluster, return_inverse=True)[1], dim=0
pred["forces"],
effective_inicies,
dim=0,
)

return torch.mean(torch.square(cluster_forces_ref - cluster_forces_pred))


Expand Down
28 changes: 28 additions & 0 deletions mace/tools/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional

import torch
from typing import List, Tuple


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
Expand Down Expand Up @@ -110,3 +111,30 @@ def scatter_mean(
else:
out.div_(count, rounding_mode="floor")
return out


def compute_effective_index(indices: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes an effective index from multiple index tensors. Useful for multi index scatter operations.
Args:
indices (List[torch.Tensor]): List of index tensors, each of shape (N,).
Returns:
effective_index (torch.Tensor): Tensor of shape (N,), where each element
is a unique integer representing the combination of indices.
unique_combinations (torch.Tensor): Tensor containing unique combinations
of indices, shape (num_unique_combinations, num_indices).
"""
# Stack indices to shape (num_indices, N)
indices_stack = torch.stack(indices, dim=0) # Shape: (num_indices, N)

# Transpose to get combinations per element
index_combinations = indices_stack.t() # Shape: (N, num_indices)

# Find unique combinations and get inverse indices
unique_combinations, inverse_indices = torch.unique(
index_combinations, dim=0, return_inverse=True
)

return inverse_indices, unique_combinations
7 changes: 4 additions & 3 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from . import torch_geometric
from .checkpoint import CheckpointHandler, CheckpointState
from .scatter import scatter_sum
from .scatter import scatter_sum, compute_effective_index
from .torch_tools import to_numpy
from .utils import (
MetricsLogger,
Expand Down Expand Up @@ -491,14 +491,15 @@ def update(self, batch, output): # pylint: disable=arguments-differ
self.delta_fs.append(batch.forces - output["forces"])
if output.get("forces") is not None and batch.cluster is not None:
self.clusterFs_computed += 1.0
effective_inicies, _ = compute_effective_index([batch.batch, batch.cluster])
cluster_forces_ref = scatter_sum(
batch["forces"],
torch.unique(batch.cluster, return_inverse=True)[1],
effective_inicies,
dim=0,
)
cluster_forces_pred = scatter_sum(
output["forces"],
torch.unique(batch.cluster, return_inverse=True)[1],
effective_inicies,
dim=0,
)
self.delta_cluster_forces.append(cluster_forces_ref - cluster_forces_pred)
Expand Down

0 comments on commit 5063f5e

Please sign in to comment.