diff --git a/crystal_diffusion/models/egnn_utils.py b/crystal_diffusion/models/egnn_utils.py index 9188a8e0..d7ad60a5 100644 --- a/crystal_diffusion/models/egnn_utils.py +++ b/crystal_diffusion/models/egnn_utils.py @@ -3,7 +3,8 @@ import torch from crystal_diffusion.models.mace_utils import get_adj_matrix -from crystal_diffusion.utils.basis_transformations import get_positions_from_coordinates +from crystal_diffusion.utils.basis_transformations import \ + get_positions_from_coordinates def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor: