Skip to content

Commit

Permalink
modify common/nblist.py, add the class NeighborListDp
Browse files Browse the repository at this point in the history
  • Loading branch information
okihane committed Dec 25, 2024
1 parent 12967a3 commit 7d4336a
Showing 1 changed file with 73 additions and 6 deletions.
79 changes: 73 additions & 6 deletions dmff/common/nblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,73 @@
freud = None
import warnings
warnings.warn("WARNING: freud not installed, users need to create neighbor list by themselves.")
try:
import dpnblist
except ImportError:
dpnblist = None
import warnings
warnings.warn("WARNING: dpdpnblist not installed, users need to create neighbor list by themselves.")

class NeighborListDp:
def __init__(self, alg_type, box, rcut, cov_map, padding=True):
if dpnblist is None:
raise ImportError("dpnblist not installed.")
self.box = dpnblist.Box([box[0][0], box[1][1], box[2][2]], [90.0, 90.0, 90.0])
self.nb = dpnblist.NeighborList(alg_type)
self.flag = False
self.rcut = rcut
self.capacity_multiplier = None
self.padding = padding
self.cov_map = cov_map

def _do_cov_map(self, pairs):
nbond = self.cov_map[pairs[:, 0], pairs[:, 1]]
pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1)
return pairs

def allocate(self, coords, box=None):
self._positions = coords # cache it
dbox = dpnblist.Box([box[0][0], box[1][1], box[2][2]], [90.0, 90.0, 90.0]) if box is not None else self.box
self.nb.build(dbox, coords, self.rcut)
pair = self.nb.get_neighbor_pair()
nlist = np.vstack((pair[:, 0], pair[:, 1])).T
nlist = nlist.astype(np.int32)
msk = (nlist[:, 0] - nlist[:, 1]) < 0
nlist = nlist[msk]
if self.capacity_multiplier is None:
self.capacity_multiplier = int(nlist.shape[0] * 1.3)

if not self.padding:
self._pairs = self._do_cov_map(nlist)
return self._pairs

self.capacity_multiplier = max(self.capacity_multiplier, nlist.shape[0])
padding_width = self.capacity_multiplier - nlist.shape[0]
if padding_width == 0:
self._pairs = self._do_cov_map(nlist)
return self._pairs
elif padding_width > 0:
padding = np.ones((self.capacity_multiplier - nlist.shape[0], 2), dtype=np.int32) * coords.shape[0]
nlist = np.vstack((nlist, padding))
self._pairs = self._do_cov_map(nlist)
return self._pairs
else:
raise ValueError("padding width < 0")

def update(self, positions, box=None):
self.allocate(positions, box)

@property
def pairs(self):
return self._pairs

@property
def scaled_pairs(self):
return self._pairs

@property
def positions(self):
return self._positions


class NeighborListFreud:
Expand Down Expand Up @@ -69,10 +136,10 @@ def scaled_pairs(self):
def positions(self):
return self._positions


class NeighborList(NeighborListFreud):
...


class NoCutoffNeighborList:

def __init__(self, cov_map, padding=True):
Expand All @@ -88,8 +155,8 @@ def _do_cov_map(self, pairs):
def allocate(self, coords, box=None):
self._positions = coords # cache it
natoms = coords.shape[0]
nblist = np.fromiter(permutations(range(natoms), 2), dtype=np.dtype((int, 2)))
nlist = nblist[nblist[:, 0] < nblist[:, 1]]
dpnblist = np.fromiter(permutations(range(natoms), 2), dtype=np.dtype(int, 2))
nlist = dpnblist[dpnblist[:, 0] < dpnblist[:, 1]]
if self.capacity_multiplier is None:
self.capacity_multiplier = int(nlist.shape[0] * 1.3)

Expand Down Expand Up @@ -135,8 +202,8 @@ def __init__(self, rcut, cov_map, padding=True):
def allocate(self, coords):
self._positions = coords # cache it
natoms = coords.shape[0]
nblist = np.fromiter(permutations(range(natoms), 2), dtype=np.dtype(int, 2))
nlist = nblist[nblist[:, 0] < nblist[:, 1]]
dpnblist = np.fromiter(permutations(range(natoms), 2), dtype=np.dtype(int, 2))
nlist = dpnblist[dpnblist[:, 0] < dpnblist[:, 1]]
distances = np.linalg.norm(coords[nlist[:, 0]] - coords[nlist[:, 1]], axis=1)
nlist = nlist[distances < self.rcut]
if self.capacity_multiplier is None:
Expand All @@ -157,4 +224,4 @@ def allocate(self, coords):
self._pairs = self._do_cov_map(nlist)
return self._pairs
else:
raise ValueError("padding width < 0")
raise ValueError("padding width < 0")

0 comments on commit 7d4336a

Please sign in to comment.