Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updating neghbor list #32

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/morered/configs/experiment/vp_gauss_ddpm.yaml
Original file line number Diff line number Diff line change
@@ -48,8 +48,7 @@ data:
diffusion_process: ${globals.diffusion_process}
time_key: ${globals.time_target_key}

- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${globals.cutoff}
- _target_: morered.transform.AllToAllNeighborList
- _target_: schnetpack.transform.CastTo32

model:
2 changes: 1 addition & 1 deletion src/morered/sampling/ddpm.py
Original file line number Diff line number Diff line change
@@ -118,7 +118,7 @@ def denoise(

# set all atoms as neighbors and compute neighbors only once before starting.
if not self.recompute_neighbors:
batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device)
batch = compute_neighbors(batch, fully_connected=True, device=self.device)

# history of the reverse steps
hist = []
2 changes: 1 addition & 1 deletion src/morered/sampling/morered.py
Original file line number Diff line number Diff line change
@@ -93,7 +93,7 @@ def denoise(

# set all atoms as neighbors and compute neighbors only once before starting.
if not self.recompute_neighbors:
batch = compute_neighbors(batch, cutoff=9999999.0, device=self.device)
batch = compute_neighbors(batch, fully_connected=True, device=self.device)

# initialize convergence flag for each molecule
converged = torch.zeros_like(
25 changes: 24 additions & 1 deletion src/morered/transform/transforms.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,30 @@
from morered.processes.base import DiffusionProcess
from morered.utils import batch_center_systems

__all__ = ["BatchSubtractCenterOfMass", "Diffuse"]
__all__ = ["AllToAllNeighborList", "BatchSubtractCenterOfMass", "Diffuse"]


class AllToAllNeighborList(trn.NeighborListTransform):
"""
Calculate a full neighbor list for all atoms in the system.
Faster than other methods and useful for small systems.
"""

def __init__(self):
# pass dummy large cutoff as all neighbors are connceted
super().__init__(cutoff=1e8)

def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
n_atoms = Z.shape[0]
idx_i = torch.arange(n_atoms).repeat_interleave(n_atoms)
idx_j = torch.arange(n_atoms).repeat(n_atoms)

mask = idx_i != idx_j
idx_i = idx_i[mask]
idx_j = idx_j[mask]

offset = torch.zeros(n_atoms * (n_atoms - 1), 3, dtype=positions.dtype)
return idx_i, idx_j, offset


class BatchSubtractCenterOfMass(trn.Transform):
17 changes: 13 additions & 4 deletions src/morered/utils.py
Original file line number Diff line number Diff line change
@@ -4,14 +4,15 @@
from typing import Dict, Optional

import numpy as np
import schnetpack.transform as trn
import torch
from ase import Atoms, build
from ase.data import chemical_symbols, covalent_radii
from tqdm import tqdm

import schnetpack.transform as trn
from schnetpack import properties
from schnetpack.data.loader import _atoms_collate_fn
from tqdm import tqdm

import morered as mrd
from morered.bonds import allowed_bonds_dict, bonds1, bonds2, bonds3


@@ -70,6 +71,7 @@ def compute_neighbors(
old_batch,
neighbor_list_trn: Optional[trn.Transform] = None,
cutoff=5.0,
fully_connected=False,
additional_keys=[],
device=None,
):
@@ -80,6 +82,8 @@ def compute_neighbors(
old_batch: batch of systems to compute the neighbors for
neighbor_list_trn: transform to compute the neighbors
cutoff: cutoff radius for the neighbor list
fully_connected: if True, all atoms are connected to each other.
Ignores the cutoff.
additional_keys: additional keys to be included in the new batch
device: Pytorch device
"""
@@ -90,7 +94,12 @@ def compute_neighbors(
f_dtype = old_batch[properties.R].dtype

# initialize the neighbor list transform
neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList(cutoff=cutoff)
if fully_connected:
neighbors_calculator = mrd.transform.AllToAllNeighborList()
else:
neighbors_calculator = neighbor_list_trn or trn.MatScipyNeighborList(
cutoff=cutoff
)

batch = []