Skip to content

Commit

Permalink
Merge pull request #32 from khaledkah/fixing-neighbor-lists
Browse files Browse the repository at this point in the history
updating neghbor list
  • Loading branch information
khaledkah authored Jul 9, 2024
2 parents ebde24d + c6b1b5c commit 9f7df9e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
3 changes: 1 addition & 2 deletions src/morered/configs/experiment/vp_gauss_ddpm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ data:
diffusion_process: ${globals.diffusion_process}
time_key: ${globals.time_target_key}

- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${globals.cutoff}
- _target_: morered.transform.AllToAllNeighborList
- _target_: schnetpack.transform.CastTo32

model:
Expand Down
2 changes: 1 addition & 1 deletion src/morered/sampling/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def denoise(

# set all atoms as neighbors and compute neighbors only once before starting.
if not self.recompute_neighbors:
batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device)
batch = compute_neighbors(batch, fully_connected=True, device=self.device)

# history of the reverse steps
hist = []
Expand Down
2 changes: 1 addition & 1 deletion src/morered/sampling/morered.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def denoise(

# set all atoms as neighbors and compute neighbors only once before starting.
if not self.recompute_neighbors:
batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device)
batch = compute_neighbors(batch, fully_connected=True, device=self.device)

# initialize convergence flag for each molecule
converged = torch.zeros_like(
Expand Down
25 changes: 24 additions & 1 deletion src/morered/transform/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,30 @@
from morered.processes.base import DiffusionProcess
from morered.utils import batch_center_systems

__all__ = ["BatchSubtractCenterOfMass", "Diffuse"]
__all__ = ["AllToAllNeighborList", "BatchSubtractCenterOfMass", "Diffuse"]


class AllToAllNeighborList(trn.NeighborListTransform):
"""
Calculate a full neighbor list for all atoms in the system.
Faster than other methods and useful for small systems.
"""

def __init__(self):
# pass dummy large cutoff as all neighbors are connceted
super().__init__(cutoff=1e8)

def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
n_atoms = Z.shape[0]
idx_i = torch.arange(n_atoms).repeat_interleave(n_atoms)
idx_j = torch.arange(n_atoms).repeat(n_atoms)

mask = idx_i != idx_j
idx_i = idx_i[mask]
idx_j = idx_j[mask]

offset = torch.zeros(n_atoms * (n_atoms - 1), 3, dtype=positions.dtype)
return idx_i, idx_j, offset


class BatchSubtractCenterOfMass(trn.Transform):
Expand Down
17 changes: 13 additions & 4 deletions src/morered/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Dict, Optional

import numpy as np
import schnetpack.transform as trn
import torch
from ase import Atoms, build
from ase.data import chemical_symbols, covalent_radii
from tqdm import tqdm

import schnetpack.transform as trn
from schnetpack import properties
from schnetpack.data.loader import _atoms_collate_fn
from tqdm import tqdm

import morered as mrd
from morered.bonds import allowed_bonds_dict, bonds1, bonds2, bonds3


Expand Down Expand Up @@ -70,6 +71,7 @@ def compute_neighbors(
old_batch,
neighbor_list_trn: Optional[trn.Transform] = None,
cutoff=5.0,
fully_connected=False,
additional_keys=[],
device=None,
):
Expand All @@ -80,6 +82,8 @@ def compute_neighbors(
old_batch: batch of systems to compute the neighbors for
neighbor_list_trn: transform to compute the neighbors
cutoff: cutoff radius for the neighbor list
fully_connected: if True, all atoms are connected to each other.
Ignores the cutoff.
additional_keys: additional keys to be included in the new batch
device: Pytorch device
"""
Expand All @@ -90,7 +94,12 @@ def compute_neighbors(
f_dtype = old_batch[properties.R].dtype

# initialize the neighbor list transform
neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList(cutoff=cutoff)
if fully_connected:
neighbors_calculator = mrd.transform.AllToAllNeighborList()
else:
neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList(
cutoff=cutoff
)

batch = []

Expand Down

0 comments on commit 9f7df9e

Please sign in to comment.