From c6f8f9607375fdbe653d597c6357a384f403bc8e Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Fri, 11 Oct 2024 08:37:22 -0400 Subject: [PATCH] code review --- crystal_diffusion/models/egnn_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/egnn_utils.py b/crystal_diffusion/models/egnn_utils.py index c9dc3a3d..9188a8e0 100644 --- a/crystal_diffusion/models/egnn_utils.py +++ b/crystal_diffusion/models/egnn_utils.py @@ -3,6 +3,7 @@ import torch from crystal_diffusion.models.mace_utils import get_adj_matrix +from crystal_diffusion.utils.basis_transformations import get_positions_from_coordinates def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor: @@ -103,16 +104,17 @@ def get_edges_with_radial_cutoff(relative_coordinates: torch.Tensor, unit_cell: Returns: long tensor of size [number of edges, 2] with edge indices """ - cartesian_coordinates = relative_coordinates @ unit_cell # convert back to cartesian coordinates + # get cartesian coordinates from relative coordinates + cartesian_coordinates = get_positions_from_coordinates(relative_coordinates, unit_cell) adj_matrix, _, _, _ = get_adj_matrix(cartesian_coordinates, unit_cell, radial_cutoff) # adj_matrix is a n_edges x 2 tensor with duplicates with different shifts. # the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts # and possibly ignore the multiplicities i.e. no need to sum twice the contribution of a neighbor that we see - # in the unitcell and in a shifted unit cell. + # in the unit cell and in a shifted unit cell. # TODO check this statement - test with and without multiplicities - just remove the duplicate drop that follows to # test the w/o multiplicities case if drop_duplicate_edges: - adj_matrix = torch.unique(adj_matrix, dim=1) # compare x[0,:] to x[1, :] to find duplicates and drop them + adj_matrix = torch.unique(adj_matrix, dim=1) # MACE adj calculations returns a (2, n_edges) tensor and EGNN expects a (n_edges, 2) tensor adj_matrix = adj_matrix.transpose(0, 1)