diff --git a/crystal_diffusion/models/egnn.py b/crystal_diffusion/models/egnn.py
index a42dca50..5170d2d8 100644
--- a/crystal_diffusion/models/egnn.py
+++ b/crystal_diffusion/models/egnn.py
@@ -34,6 +34,7 @@ def __init__(
         attention: bool = False,
         normalize: bool = False,
         coords_agg: str = "mean",
+        message_agg: str = "mean",
         tanh: bool = False,
     ):
         """E_GCL layer initialization.
@@ -52,7 +53,8 @@ def __init__(
             attention: if True, multiply the message output by a gated value of the output. Defaults to False.
             normalize: if True, use a normalized version of the coordinates update i.e. x_i^l - x_j^l would be a unit
                 vector in eq. 4 in https://arxiv.org/pdf/2102.09844. Defaults to False.
-            coords_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
+            coords_agg: Use a mean or sum aggregation for the coordinates update. Defaults to mean.
+            message_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
             tanh: if True, add a tanh non-linearity after the coordinates update. Defaults to False.
         """
         super(E_GCL, self).__init__()
@@ -65,6 +67,9 @@ def __init__(
         if coords_agg not in ["mean", "sum"]:
             raise ValueError(f"coords_agg should be mean or sum. Got {coords_agg}")
         self.coords_agg_fn = unsorted_segment_sum if coords_agg == "sum" else unsorted_segment_mean
+        if message_agg not in ["mean", "sum"]:
+            raise ValueError(f"message_agg should be mean or sum. Got {message_agg}")
+        self.msg_agg_fn = unsorted_segment_sum if message_agg == "sum" else unsorted_segment_mean
 
         # message update MLP i.e. message m_{ij} used in the graph neural network.
         # \phi_e is eq. (3) in https://arxiv.org/pdf/2102.09844
@@ -151,7 +156,7 @@ def node_model(self, x: torch.Tensor, edge_index: torch.Tensor, messages: torch.
             updated node features. size: number of nodes, output_size
         """
         row = edge_index[:, 0]
-        agg = unsorted_segment_sum(messages, row, num_segments=x.size(0))  # sum messages m_i = \sum_j m_{ij}
+        agg = self.msg_agg_fn(messages, row, num_segments=x.size(0))  # aggregate (sum or mean) m_{ij} over j into m_i
         agg = torch.cat([x, agg], dim=1)  # concat h_i and m_i
         out = self.node_mlp(agg)
         if self.residual:  # optional skip connection
@@ -245,7 +250,8 @@ def __init__(
         normalize: bool = False,
         tanh: bool = False,
         coords_agg: str = "mean",
-        n_layers: int = 4
+        message_agg: str = "mean",
+        n_layers: int = 4,
     ):
         """EGNN model stacking multiple E_GCL layers.
 
@@ -262,7 +268,8 @@ def __init__(
             attention: if True, multiply the message output by a gated value of the output. Defaults to False.
             normalize: if True, use a normalized version of the coordinates update i.e. x_i^l - x_j^l would be a unit
                 vector in eq. 4 in https://arxiv.org/pdf/2102.09844. Defaults to False.
-            coords_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
+            coords_agg: Use a mean or sum aggregation for the coordinates update. Defaults to mean.
+            message_agg: Use a mean or sum aggregation for the messages. Defaults to mean.
             tanh: if True, add a tanh non-linearity after the coordinates update. Defaults to False.
             n_layers: number of E_GCL layers. Defaults to 4.
         """
@@ -286,7 +293,8 @@ def __init__(
                     attention=attention,
                     normalize=normalize,
                     coords_agg=coords_agg,
-                    tanh=tanh
+                    message_agg=message_agg,
+                    tanh=tanh,
                 )
             )
 
diff --git a/crystal_diffusion/models/egnn_utils.py b/crystal_diffusion/models/egnn_utils.py
index 283090ad..d7ad60a5 100644
--- a/crystal_diffusion/models/egnn_utils.py
+++ b/crystal_diffusion/models/egnn_utils.py
@@ -2,6 +2,10 @@
 
 import torch
 
+from crystal_diffusion.models.graph_utils import get_adj_matrix
+from crystal_diffusion.utils.basis_transformations import \
+    get_positions_from_coordinates
+
 
 def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
     """Sum all the elements in data by their ids.
@@ -85,3 +89,34 @@ def get_edges_batch(n_nodes: int, batch_size: int) -> torch.Tensor:
         all_edges.append(edges + n_nodes * i)
     edges = torch.cat(all_edges)
     return edges
+
+
+def get_edges_with_radial_cutoff(relative_coordinates: torch.Tensor, unit_cell: torch.Tensor,
+                                 radial_cutoff: float = 4.0, drop_duplicate_edges: bool = True) -> torch.Tensor:
+    """Get edges for a batch with a cutoff based on distance.
+
+    Args:
+        relative_coordinates: batch x n_atom x spatial dimension tensor with relative coordinates
+        unit_cell: batch x spatial dimension x spatial dimension tensor with the unit cell vectors
+        radial_cutoff (optional): cutoff distance in Angstrom. Defaults to 4.0
+        drop_duplicate_edges (optional): if True, return only 1 instance of each edge. If False, return each edge
+            multiple times, depending on the unit cell shift multiplicities. Defaults to True.
+
+    Returns:
+        long tensor of size [number of edges, 2] with edge indices
+    """
+    # get cartesian coordinates from relative coordinates
+    cartesian_coordinates = get_positions_from_coordinates(relative_coordinates, unit_cell)
+    adj_matrix, _, _, _ = get_adj_matrix(cartesian_coordinates, unit_cell, radial_cutoff)
+    # adj_matrix is a (2, n_edges) tensor; an edge is repeated once per unit cell shift that connects the pair.
+    # the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts
+    # and possibly ignore the multiplicities i.e. no need to sum twice the contribution of a neighbor that we see
+    # in the unit cell and in a shifted unit cell.
+    # TODO check this statement - test with and without multiplicities - just remove the duplicate drop that follows to
+    # test the w/o multiplicities case
+    if drop_duplicate_edges:
+        adj_matrix = torch.unique(adj_matrix, dim=1)
+    # get_adj_matrix returns a (2, n_edges) tensor and EGNN expects a (n_edges, 2) tensor
+    adj_matrix = adj_matrix.transpose(0, 1)
+
+    return adj_matrix
diff --git a/crystal_diffusion/models/graph_utils.py b/crystal_diffusion/models/graph_utils.py
new file mode 100644
index 00000000..602e09ad
--- /dev/null
+++ b/crystal_diffusion/models/graph_utils.py
@@ -0,0 +1,48 @@
+from typing import Tuple
+
+import torch
+
+from crystal_diffusion.utils.neighbors import (
+    get_periodic_adjacency_information,
+    shift_adjacency_matrix_indices_for_graph_batching)
+
+
+def get_adj_matrix(positions: torch.Tensor,
+                   basis_vectors: torch.Tensor,
+                   radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Create the adjacency and shift matrices.
+
+    Args:
+        positions : atomic positions, assumed to be within the unit cell, in Euclidean coordinates.
+                               Dimension [batch_size, max_number_of_atoms, 3]
+        basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed
+                        to be vertically stacked, namely
+                                            [-- a1 --]
+                                            [-- a2 --]
+                                            [-- a3 --]
+                        Dimension [batch_size, 3, 3].
+        radial_cutoff : largest distance between neighbors.
+
+    Returns:
+        adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor,
+        shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor
+        batch_indices: for each node, this indicates which batch item it originally belonged to.
+        number_of_edges: for each element in the batch, how many edges belong to it
+    """
+    batch_size, number_of_atoms, spatial_dimensions = positions.shape
+
+    adjacency_info = get_periodic_adjacency_information(positions, basis_vectors, radial_cutoff)
+
+    # The indices in the adjacency matrix must be shifted to account for the batching
+    # of multiple distinct structures into a single disconnected graph.
+    adjacency_matrix = adjacency_info.adjacency_matrix
+    number_of_edges = adjacency_info.number_of_edges
+    shifted_adjacency_matrix = shift_adjacency_matrix_indices_for_graph_batching(adjacency_matrix,
+                                                                                 number_of_edges,
+                                                                                 number_of_atoms)
+    shifts = adjacency_info.shifts
+    batch_indices = adjacency_info.node_batch_indices
+
+    number_of_edges = adjacency_info.number_of_edges
+
+    return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges
diff --git a/crystal_diffusion/models/mace_utils.py b/crystal_diffusion/models/mace_utils.py
index 467f2142..3b45ef4d 100644
--- a/crystal_diffusion/models/mace_utils.py
+++ b/crystal_diffusion/models/mace_utils.py
@@ -6,51 +6,8 @@
 from e3nn import o3
 from torch_geometric.data import Data
 
+from crystal_diffusion.models.graph_utils import get_adj_matrix
 from crystal_diffusion.namespace import NOISY_CARTESIAN_POSITIONS, UNIT_CELL
-from crystal_diffusion.utils.neighbors import (
-    get_periodic_adjacency_information,
-    shift_adjacency_matrix_indices_for_graph_batching)
-
-
-def get_adj_matrix(positions: torch.Tensor,
-                   basis_vectors: torch.Tensor,
-                   radial_cutoff: float = 4.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Create the adjacency and shift matrices.
-
-    Args:
-        positions : atomic positions, assumed to be within the unit cell, in Euclidean coordinates.
-                               Dimension [batch_size, max_number_of_atoms, 3]
-        basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed
-                        to be vertically stacked, namely
-                                            [-- a1 --]
-                                            [-- a2 --]
-                                            [-- a3 --]
-                        Dimension [batch_size, 3, 3].
-        radial_cutoff : largest distance between neighbors.
-
-    Returns:
-        adjacency matrix: The (src, dst) node indices, as a [2, num_edge] tensor,
-        shift matrix: The lattice vector shifts between source and destination, as a [num_edge, 3] tensor
-        batch_indices: for each node, this indicates which batch item it originally belonged to.
-        number_of_edges: for each element in the batch, how many edges belong to it
-    """
-    batch_size, number_of_atoms, spatial_dimensions = positions.shape
-
-    adjacency_info = get_periodic_adjacency_information(positions, basis_vectors, radial_cutoff)
-
-    # The indices in the adjacency matrix must be shifted to account for the batching
-    # of multiple distinct structures into a single disconnected graph.
-    adjacency_matrix = adjacency_info.adjacency_matrix
-    number_of_edges = adjacency_info.number_of_edges
-    shifted_adjacency_matrix = shift_adjacency_matrix_indices_for_graph_batching(adjacency_matrix,
-                                                                                 number_of_edges,
-                                                                                 number_of_atoms)
-    shifts = adjacency_info.shifts
-    batch_indices = adjacency_info.node_batch_indices
-
-    number_of_edges = adjacency_info.number_of_edges
-
-    return shifted_adjacency_matrix, shifts, batch_indices, number_of_edges
 
 
 def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data:
diff --git a/crystal_diffusion/models/score_networks/egnn_score_network.py b/crystal_diffusion/models/score_networks/egnn_score_network.py
index b8611774..24f30d7a 100644
--- a/crystal_diffusion/models/score_networks/egnn_score_network.py
+++ b/crystal_diffusion/models/score_networks/egnn_score_network.py
@@ -5,10 +5,12 @@
 import torch
 
 from crystal_diffusion.models.egnn import EGNN
-from crystal_diffusion.models.egnn_utils import get_edges_batch
+from crystal_diffusion.models.egnn_utils import (get_edges_batch,
+                                                 get_edges_with_radial_cutoff)
 from crystal_diffusion.models.score_networks import ScoreNetworkParameters
 from crystal_diffusion.models.score_networks.score_network import ScoreNetwork
-from crystal_diffusion.namespace import NOISE, NOISY_RELATIVE_COORDINATES
+from crystal_diffusion.namespace import (NOISE, NOISY_RELATIVE_COORDINATES,
+                                         UNIT_CELL)
 
 
 @dataclass(kw_only=True)
@@ -26,7 +28,11 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters):
     normalize: bool = False
     tanh: bool = False
     coords_agg: str = "mean"
+    message_agg: str = "mean"
     n_layers: int = 4
+    edges: str = 'fully_connected'
+    radial_cutoff: float = 4.0
+    drop_duplicate_edges: bool = True
 
 
 class EGNNScoreNetwork(ScoreNetwork):
@@ -51,6 +57,12 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
         self.register_parameter('projection_matrices',
                                 torch.nn.Parameter(projection_matrices, requires_grad=False))
 
+        self.edges = hyper_params.edges
+        if self.edges not in ["fully_connected", "radial_cutoff"]:
+            raise ValueError(f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}')
+        self.radial_cutoff = hyper_params.radial_cutoff
+        self.drop_duplicate_edges = hyper_params.drop_duplicate_edges
+
         self.egnn = EGNN(
             input_size=self.number_of_features_per_node,
             message_n_hidden_dimensions=hyper_params.message_n_hidden_dimensions,
@@ -64,6 +76,7 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
             normalize=hyper_params.normalize,
             tanh=hyper_params.tanh,
             coords_agg=hyper_params.coords_agg,
+            message_agg=hyper_params.message_agg,
             n_layers=hyper_params.n_layers,
         )
 
@@ -147,9 +160,15 @@ def _forward_unchecked(
         relative_coordinates = batch[NOISY_RELATIVE_COORDINATES]
         batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape
 
-        edges = get_edges_batch(
-            n_nodes=number_of_atoms, batch_size=batch_size
-        )
+        if self.edges == "fully_connected":
+            edges = get_edges_batch(
+                n_nodes=number_of_atoms, batch_size=batch_size
+            )
+        else:
+            edges = get_edges_with_radial_cutoff(
+                relative_coordinates, batch[UNIT_CELL], self.radial_cutoff,
+                drop_duplicate_edges=self.drop_duplicate_edges
+            )
         edges = edges.to(relative_coordinates.device)