diff --git a/src/morered/configs/experiment/vp_gauss_ddpm.yaml b/src/morered/configs/experiment/vp_gauss_ddpm.yaml index 5df000b..42c1055 100644 --- a/src/morered/configs/experiment/vp_gauss_ddpm.yaml +++ b/src/morered/configs/experiment/vp_gauss_ddpm.yaml @@ -48,8 +48,7 @@ data: diffusion_process: ${globals.diffusion_process} time_key: ${globals.time_target_key} - - _target_: schnetpack.transform.MatScipyNeighborList - cutoff: ${globals.cutoff} + - _target_: morered.transform.AllToAllNeighborList - _target_: schnetpack.transform.CastTo32 model: diff --git a/src/morered/sampling/ddpm.py b/src/morered/sampling/ddpm.py index e38c54a..f4658e6 100644 --- a/src/morered/sampling/ddpm.py +++ b/src/morered/sampling/ddpm.py @@ -118,7 +118,7 @@ def denoise( # set all atoms as neighbors and compute neighbors only once before starting. if not self.recompute_neighbors: - batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device) + batch = compute_neighbors(batch, fully_connected=True, device=self.device) # history of the reverse steps hist = [] diff --git a/src/morered/sampling/morered.py b/src/morered/sampling/morered.py index d81cd05..530a885 100644 --- a/src/morered/sampling/morered.py +++ b/src/morered/sampling/morered.py @@ -93,7 +93,7 @@ def denoise( # set all atoms as neighbors and compute neighbors only once before starting. if not self.recompute_neighbors: - batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device) + batch = compute_neighbors(batch, fully_connected=True, device=self.device) # initialize convergence flag for each molecule converged = torch.zeros_like( diff --git a/src/morered/transform/transforms.py b/src/morered/transform/transforms.py index b81960a..46c394e 100644 --- a/src/morered/transform/transforms.py +++ b/src/morered/transform/transforms.py @@ -8,7 +8,30 @@ from morered.processes.base import DiffusionProcess from morered.utils import batch_center_systems -__all__ = ["BatchSubtractCenterOfMass", "Diffuse"] +__all__ = ["AllToAllNeighborList", "BatchSubtractCenterOfMass", "Diffuse"] + + +class AllToAllNeighborList(trn.NeighborListTransform): + """ + Calculate a full neighbor list for all atoms in the system. + Faster than other methods and useful for small systems. + """ + + def __init__(self): + # pass dummy large cutoff as all neighbors are connceted + super().__init__(cutoff=1e8) + + def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): + n_atoms = Z.shape[0] + idx_i = torch.arange(n_atoms).repeat_interleave(n_atoms) + idx_j = torch.arange(n_atoms).repeat(n_atoms) + + mask = idx_i != idx_j + idx_i = idx_i[mask] + idx_j = idx_j[mask] + + offset = torch.zeros(n_atoms * (n_atoms - 1), 3, dtype=positions.dtype) + return idx_i, idx_j, offset class BatchSubtractCenterOfMass(trn.Transform): diff --git a/src/morered/utils.py b/src/morered/utils.py index 306d9d0..d426188 100644 --- a/src/morered/utils.py +++ b/src/morered/utils.py @@ -4,14 +4,15 @@ from typing import Dict, Optional import numpy as np +import schnetpack.transform as trn import torch from ase import Atoms, build from ase.data import chemical_symbols, covalent_radii -from tqdm import tqdm - -import schnetpack.transform as trn from schnetpack import properties from schnetpack.data.loader import _atoms_collate_fn +from tqdm import tqdm + +import morered as mrd from morered.bonds import allowed_bonds_dict, bonds1, bonds2, bonds3 @@ -70,6 +71,7 @@ def compute_neighbors( old_batch, neighbor_list_trn: Optional[trn.Transform] = None, cutoff=5.0, + fully_connected=False, additional_keys=[], device=None, ): @@ -80,6 +82,8 @@ def compute_neighbors( old_batch: batch of systems to compute the neighbors for neighbor_list_trn: transform to compute the neighbors cutoff: cutoff radius for the neighbor list + fully_connected: if True, all atoms are connected to each other. + Ignores the cutoff. additional_keys: additional keys to be included in the new batch device: Pytorch device """ @@ -90,7 +94,12 @@ def compute_neighbors( f_dtype = old_batch[properties.R].dtype # initialize the neighbor list transform - neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList(cutoff=cutoff) + if fully_connected: + neighbors_calculator = mrd.transform.AllToAllNeighborList() + else: + neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList( + cutoff=cutoff + ) batch = []