Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Oct 11, 2024
1 parent 1698440 commit c6f8f96
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions crystal_diffusion/models/egnn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from crystal_diffusion.models.mace_utils import get_adj_matrix
from crystal_diffusion.utils.basis_transformations import get_positions_from_coordinates


def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
Expand Down Expand Up @@ -103,16 +104,17 @@ def get_edges_with_radial_cutoff(relative_coordinates: torch.Tensor, unit_cell:
Returns:
long tensor of size [number of edges, 2] with edge indices
"""
cartesian_coordinates = relative_coordinates @ unit_cell # convert back to cartesian coordinates
# get cartesian coordinates from relative coordinates
cartesian_coordinates = get_positions_from_coordinates(relative_coordinates, unit_cell)
adj_matrix, _, _, _ = get_adj_matrix(cartesian_coordinates, unit_cell, radial_cutoff)
# adj_matrix is a n_edges x 2 tensor with duplicates with different shifts.
# the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts
# and possibly ignore the multiplicities i.e. no need to sum twice the contribution of a neighbor that we see
# in the unitcell and in a shifted unit cell.
# in the unit cell and in a shifted unit cell.
# TODO check this statement - test with and without multiplicities - just remove the duplicate drop that follows to
# test the w/o multiplicities case
if drop_duplicate_edges:
adj_matrix = torch.unique(adj_matrix, dim=1) # compare x[0,:] to x[1, :] to find duplicates and drop them
adj_matrix = torch.unique(adj_matrix, dim=1)
# MACE adj calculations returns a (2, n_edges) tensor and EGNN expects a (n_edges, 2) tensor
adj_matrix = adj_matrix.transpose(0, 1)

Expand Down

0 comments on commit c6f8f96

Please sign in to comment.