Skip to content

Commit

Permalink
Merge pull request #73 from mila-iqia/egnn_radial_cutoff
Browse files Browse the repository at this point in the history
radial cutoff for the adjacency matrix in egnn
  • Loading branch information
sblackburn86 authored Oct 11, 2024
2 parents 9feaa7f + 977f99b commit 91e3f7c
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 54 deletions.
16 changes: 11 additions & 5 deletions crystal_diffusion/models/egnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
attention: bool = False,
normalize: bool = False,
coords_agg: str = "mean",
message_agg: str = "mean",
tanh: bool = False,
):
"""E_GCL layer initialization.
Expand All @@ -52,7 +53,8 @@ def __init__(
attention: if True, multiply the message output by a gated value of the output. Defaults to False.
normalize: if True, use a normalized version of the coordinates update i.e. x_i^l - x_j^l would be a unit
vector in eq. 4 in https://arxiv.org/pdf/2102.09844. Defaults to False.
coords_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
coords_agg: Use a mean or sum aggregation for the coordinates update. Defaults to mean.
message_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
tanh: if True, add a tanh non-linearity after the coordinates update. Defaults to False.
"""
super(E_GCL, self).__init__()
Expand All @@ -65,6 +67,7 @@ def __init__(
if coords_agg not in ["mean", "sum"]:
raise ValueError(f"coords_agg should be mean or sum. Got {coords_agg}")
self.coords_agg_fn = unsorted_segment_sum if coords_agg == "sum" else unsorted_segment_mean
self.msg_agg_fn = unsorted_segment_sum if message_agg == "sum" else unsorted_segment_mean

# message update MLP i.e. message m_{ij} used in the graph neural network.
# \phi_e is eq. (3) in https://arxiv.org/pdf/2102.09844
Expand Down Expand Up @@ -151,7 +154,7 @@ def node_model(self, x: torch.Tensor, edge_index: torch.Tensor, messages: torch.
updated node features. size: number of nodes, output_size
"""
row = edge_index[:, 0]
agg = unsorted_segment_sum(messages, row, num_segments=x.size(0)) # sum messages m_i = \sum_j m_{ij}
agg = self.msg_agg_fn(messages, row, num_segments=x.size(0)) # sum messages m_i = \sum_j m_{ij}
agg = torch.cat([x, agg], dim=1) # concat h_i and m_i
out = self.node_mlp(agg)
if self.residual: # optional skip connection
Expand Down Expand Up @@ -245,7 +248,8 @@ def __init__(
normalize: bool = False,
tanh: bool = False,
coords_agg: str = "mean",
n_layers: int = 4
message_agg: str = "mean",
n_layers: int = 4,
):
"""EGNN model stacking multiple E_GCL layers.
Expand All @@ -262,7 +266,8 @@ def __init__(
attention: if True, multiply the message output by a gated value of the output. Defaults to False.
normalize: if True, use a normalized version of the coordinates update i.e. x_i^l - x_j^l would be a unit
vector in eq. 4 in https://arxiv.org/pdf/2102.09844. Defaults to False.
coords_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
coords_agg: Use a mean or sum aggregation for the coordinates update. Defaults to mean.
message_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
tanh: if True, add a tanh non-linearity after the coordinates update. Defaults to False.
n_layers: number of E_GCL layers. Defaults to 4.
"""
Expand All @@ -286,7 +291,8 @@ def __init__(
attention=attention,
normalize=normalize,
coords_agg=coords_agg,
tanh=tanh
message_agg=message_agg,
tanh=tanh,
)
)

Expand Down
35 changes: 35 additions & 0 deletions crystal_diffusion/models/egnn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import torch

from crystal_diffusion.models.mace_utils import get_adj_matrix
from crystal_diffusion.utils.basis_transformations import \
get_positions_from_coordinates


def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
"""Sum all the elements in data by their ids.
Expand Down Expand Up @@ -85,3 +89,34 @@ def get_edges_batch(n_nodes: int, batch_size: int) -> torch.Tensor:
all_edges.append(edges + n_nodes * i)
edges = torch.cat(all_edges)
return edges


def get_edges_with_radial_cutoff(relative_coordinates: torch.Tensor, unit_cell: torch.Tensor,
radial_cutoff: float = 4.0, drop_duplicate_edges: bool = True) -> torch.Tensor:
"""Get edges for a batch with a cutoff based on distance.
Args:
relative_coordinates: batch x n_atom x spatial dimension tensor with relative coordinates
unit_cell: batch x spatial dimension x spatial dimension tensor with the unit cell vectors
radial_cutoff (optional): cutoff distance in Angstrom. Defaults to 4.0
drop_duplicate_edges (optional): if True, return only 1 instance of each edge. If False, return each edge
multiple times, depending on the unit cell shift multiplicities. Defaults to True.
Returns:
long tensor of size [number of edges, 2] with edge indices
"""
# get cartesian coordinates from relative coordinates
cartesian_coordinates = get_positions_from_coordinates(relative_coordinates, unit_cell)
adj_matrix, _, _, _ = get_adj_matrix(cartesian_coordinates, unit_cell, radial_cutoff)
# adj_matrix is a n_edges x 2 tensor with duplicates with different shifts.
# the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts
# and possibly ignore the multiplicities i.e. no need to sum twice the contribution of a neighbor that we see
# in the unit cell and in a shifted unit cell.
# TODO check this statement - test with and without multiplicities - just remove the duplicate drop that follows to
# test the w/o multiplicities case
if drop_duplicate_edges:
adj_matrix = torch.unique(adj_matrix, dim=1)
# MACE adj calculations returns a (2, n_edges) tensor and EGNN expects a (n_edges, 2) tensor
adj_matrix = adj_matrix.transpose(0, 1)

return adj_matrix
48 changes: 48 additions & 0 deletions crystal_diffusion/models/graph_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Tuple

import torch

from crystal_diffusion.utils.neighbors import (
get_periodic_adjacency_information,
shift_adjacency_matrix_indices_for_graph_batching)


def get_adj_matrix(positions: torch.Tensor,
basis_vectors: torch.Tensor,
radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create the adjacency and shift matrices.
Args:
positions : atomic positions, assumed to be within the unit cell, in Euclidean coordinates.
Dimension [batch_size, max_number_of_atoms, 3]
basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed
to be vertically stacked, namely
[-- a1 --]
[-- a2 --]
[-- a3 --]
Dimension [batch_size, 3, 3].
radial_cutoff : largest distance between neighbors.
Returns:
adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor,
shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor
batch_indices: for each node, this indicates which batch item it originally belonged to.
number_of_edges: for each element in the batch, how many edges belong to it
"""
batch_size, number_of_atoms, spatial_dimensions = positions.shape

adjacency_info = get_periodic_adjacency_information(positions, basis_vectors, radial_cutoff)

# The indices in the adjacency matrix must be shifted to account for the batching
# of multiple distinct structures into a single disconnected graph.
adjacency_matrix = adjacency_info.adjacency_matrix
number_of_edges = adjacency_info.number_of_edges
shifted_adjacency_matrix = shift_adjacency_matrix_indices_for_graph_batching(adjacency_matrix,
number_of_edges,
number_of_atoms)
shifts = adjacency_info.shifts
batch_indices = adjacency_info.node_batch_indices

number_of_edges = adjacency_info.number_of_edges

return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges
45 changes: 1 addition & 44 deletions crystal_diffusion/models/mace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,8 @@
from e3nn import o3
from torch_geometric.data import Data

from crystal_diffusion.models.graph_utils import get_adj_matrix
from crystal_diffusion.namespace import NOISY_CARTESIAN_POSITIONS, UNIT_CELL
from crystal_diffusion.utils.neighbors import (
get_periodic_adjacency_information,
shift_adjacency_matrix_indices_for_graph_batching)


def get_adj_matrix(positions: torch.Tensor,
basis_vectors: torch.Tensor,
radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create the adjacency and shift matrices.
Args:
positions : atomic positions, assumed to be within the unit cell, in Euclidean coordinates.
Dimension [batch_size, max_number_of_atoms, 3]
basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed
to be vertically stacked, namely
[-- a1 --]
[-- a2 --]
[-- a3 --]
Dimension [batch_size, 3, 3].
radial_cutoff : largest distance between neighbors.
Returns:
adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor,
shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor
batch_indices: for each node, this indicates which batch item it originally belonged to.
number_of_edges: for each element in the batch, how many edges belong to it
"""
batch_size, number_of_atoms, spatial_dimensions = positions.shape

adjacency_info = get_periodic_adjacency_information(positions, basis_vectors, radial_cutoff)

# The indices in the adjacency matrix must be shifted to account for the batching
# of multiple distinct structures into a single disconnected graph.
adjacency_matrix = adjacency_info.adjacency_matrix
number_of_edges = adjacency_info.number_of_edges
shifted_adjacency_matrix = shift_adjacency_matrix_indices_for_graph_batching(adjacency_matrix,
number_of_edges,
number_of_atoms)
shifts = adjacency_info.shifts
batch_indices = adjacency_info.node_batch_indices

number_of_edges = adjacency_info.number_of_edges

return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges


def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data:
Expand Down
29 changes: 24 additions & 5 deletions crystal_diffusion/models/score_networks/egnn_score_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import torch

from crystal_diffusion.models.egnn import EGNN
from crystal_diffusion.models.egnn_utils import get_edges_batch
from crystal_diffusion.models.egnn_utils import (get_edges_batch,
get_edges_with_radial_cutoff)
from crystal_diffusion.models.score_networks import ScoreNetworkParameters
from crystal_diffusion.models.score_networks.score_network import ScoreNetwork
from crystal_diffusion.namespace import NOISE, NOISY_RELATIVE_COORDINATES
from crystal_diffusion.namespace import (NOISE, NOISY_RELATIVE_COORDINATES,
UNIT_CELL)


@dataclass(kw_only=True)
Expand All @@ -26,7 +28,11 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters):
normalize: bool = False
tanh: bool = False
coords_agg: str = "mean"
message_agg: str = "mean"
n_layers: int = 4
edges: str = 'fully_connected'
radial_cutoff: float = 4.0
drop_duplicate_edges: bool = True


class EGNNScoreNetwork(ScoreNetwork):
Expand All @@ -51,6 +57,12 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
self.register_parameter('projection_matrices',
torch.nn.Parameter(projection_matrices, requires_grad=False))

self.edges = hyper_params.edges
assert self.edges in ["fully_connected", "radial_cutoff"], \
f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}'
self.radial_cutoff = hyper_params.radial_cutoff
self.drop_duplicate_edges = hyper_params.drop_duplicate_edges

self.egnn = EGNN(
input_size=self.number_of_features_per_node,
message_n_hidden_dimensions=hyper_params.message_n_hidden_dimensions,
Expand All @@ -64,6 +76,7 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
normalize=hyper_params.normalize,
tanh=hyper_params.tanh,
coords_agg=hyper_params.coords_agg,
message_agg=hyper_params.message_agg,
n_layers=hyper_params.n_layers,
)

Expand Down Expand Up @@ -147,9 +160,15 @@ def _forward_unchecked(
relative_coordinates = batch[NOISY_RELATIVE_COORDINATES]
batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape

edges = get_edges_batch(
n_nodes=number_of_atoms, batch_size=batch_size
)
if self.edges == "fully_connected":
edges = get_edges_batch(
n_nodes=number_of_atoms, batch_size=batch_size
)
else:
edges = get_edges_with_radial_cutoff(
relative_coordinates, batch[UNIT_CELL], self.radial_cutoff,
drop_duplicate_edges=self.drop_duplicate_edges
)

edges = edges.to(relative_coordinates.device)

Expand Down

0 comments on commit 91e3f7c

Please sign in to comment.