Skip to content

Commit

Permalink
Merge branch 'cpu_gridify' into 'main'
Browse files Browse the repository at this point in the history
add cpu_gridify option, fix config import, remove openbabel dependency

See merge request jdurrant/deepfrag!1
  • Loading branch information
jdurrant committed May 29, 2021
2 parents 4d81a06 + a54e965 commit 44579ec
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 73 deletions.
211 changes: 199 additions & 12 deletions leadopt/grid_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,194 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
numba.cuda.atomic.max(grid, idx, val)


@numba.jit(nopython=True)
def cpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
                batch_idx, width, res, center, rot,
                point_radius, point_type, acc_type
                ):
    """Adds atoms to the grid on the CPU.

    This is the serial counterpart of gpu_gridify(): instead of one GPU
    thread per gridpoint, three nested loops visit every gridpoint in turn.
    It converts atom coordinate information to 3d voxel information by
    iterating over the list of atoms for each gridpoint and adding the effect
    of the nearby ones.

    Voxel information is stored in a 5D tensor of type: BxTxNxNxN where:
        B = batch size
        T = number of atom types (receptor + ligand)
        N = grid width (in gridpoints)

    Each invocation of this function will write information to a specific
    batch index specified by batch_idx. Additionally, the layer_offset
    parameter can be set to specify a fixed offset to add to each atom layer
    index.

    How it works:
    1. Each gridpoint coordinate is translated to a "real world" coordinate
       by centering, scaling by res, rotating by rot and translating by
       center.
    2. For each gridpoint, the list of atoms is scanned and every atom within
       a threshold contributes a value to the grid.

    Args:
        grid: array where grid information is stored; it is indexed as
            grid[batch_idx, layer_offset + k, x, y, z] and modified in place
        atom_num: number of atoms
        atom_coords: array containing (x,y,z) atom coordinates
        atom_mask: uint32 array of size atom_num containing a destination
            layer bitmask (i.e. if bit k is set, write atom to index k);
            atoms with a mask of 0 are skipped
        layer_offset: a fixed offset added to each atom layer index
        batch_idx: index specifying which batch to write information to
        width: number of grid points in each dimension
        res: distance between neighboring grid points in angstroms
            (1 == gridpoint every angstrom)
            (0.5 == gridpoint every half angstrom, e.g. tighter grid)
        center: (x,y,z) coordinate of grid center
        rot: (w,x,y,z) rotation quaternion (rot[0] is the scalar part)
        point_radius: atom radius, in the same units as the coordinates
        point_type: shape of the atom densities
            (0 == EXP, 1 == SPHERE, 2 == CUBE, 3 == GAUSSIAN, 4 == LJ,
            5 == DISCRETE); any other value contributes 0
        acc_type: accumulation type (0 == SUM, 1 == MAX); any other value
            leaves the grid untouched
    """
    # loop invariants: quaternion components, grid half-width and radii
    aw = rot[0]
    ax = rot[1]
    ay = rot[2]
    az = rot[3]

    half_width = width / 2
    half_res = res / 2

    r = point_radius
    r2 = point_radius * point_radius

    for x in range(width):
        for y in range(width):
            for z in range(width):

                # center around origin and scale by resolution
                bw = 0
                bx = (x - half_width) * res
                by = (y - half_width) * res
                bz = (z - half_width) * res

                # rotate the point p = (0, bx, by, bz) as q * p * conj(q)
                # first product: c = q * p
                cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
                cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
                cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
                cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)

                # second product: d = c * conj(q); the scalar part of d is
                # never used, so it is not computed
                dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
                dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
                dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))

                # apply translation vector
                tx = dx + center[0]
                ty = dy + center[1]
                tz = dz + center[2]

                for i in range(atom_num):
                    # fetch atom
                    fx, fy, fz = atom_coords[i]
                    mask = atom_mask[i]

                    # invisible atoms
                    if mask == 0:
                        continue

                    # quick cube bounds check
                    # NOTE(review): the bound is the *squared* radius (r2),
                    # not r; kept as-is so the output is unchanged. Confirm
                    # whether r (or 1.5 * r for GAUSSIAN/LJ) was intended.
                    if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
                        continue

                    # value to add to this gridpoint
                    val = 0.0

                    if point_type == 0:  # POINT_TYPE.EXP
                        # exponential sphere fill
                        # compute squared distance to atom
                        d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
                        if d2 > r2:
                            continue

                        # compute effect
                        val = math.exp((-2 * d2) / r2)
                    elif point_type == 1:  # POINT_TYPE.SPHERE
                        # solid sphere fill
                        # compute squared distance to atom
                        d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
                        if d2 > r2:
                            continue

                        val = 1.0
                    elif point_type == 2:  # POINT_TYPE.CUBE
                        # solid cube fill (extent set by the bounds check)
                        val = 1.0
                    elif point_type == 3:  # POINT_TYPE.GAUSSIAN
                        # (Ragoza, 2016)
                        #
                        # piecewise gaussian sphere fill
                        # compute squared distance to atom
                        d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
                        d = math.sqrt(d2)

                        if d > r * 1.5:
                            continue
                        elif d > r:
                            val = math.exp(-2.0) * ((4*d2/r2) - (12*d/r) + 9)
                        else:
                            val = math.exp((-2 * d2) / r2)
                    elif point_type == 4:  # POINT_TYPE.LJ
                        # (Jimenez, 2017) - DeepSite
                        #
                        # LJ potential
                        # compute squared distance to atom
                        d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
                        d = math.sqrt(d2)

                        if d > r * 1.5:
                            continue
                        elif d == 0:
                            # gridpoint sits exactly on the atom: r/d would
                            # raise ZeroDivisionError in nopython mode, so
                            # use the limit of the expression (1) instead
                            val = 1.0
                        else:
                            val = 1 - math.exp(-((r/d)**12))
                    elif point_type == 5:  # POINT_TYPE.DISCRETE
                        # nearest-gridpoint: the atom must lie within half a
                        # grid step of this gridpoint along every axis
                        if (abs(fx-tx) < half_res and abs(fy-ty) < half_res
                                and abs(fz-tz) < half_res):
                            val = 1.0

                    # add value to every layer whose bit is set in the mask
                    for k in range(32):
                        if (mask >> k) & 1:
                            idx = (batch_idx, layer_offset+k, x, y, z)
                            if acc_type == 0:  # ACC_TYPE.SUM
                                grid[idx] += val
                            elif acc_type == 1:  # ACC_TYPE.MAX
                                grid[idx] = max(grid[idx], val)


def mol_gridify(grid, atom_coords, atom_mask, layer_offset, batch_idx,
                width, res, center, rot, point_radius, point_type, acc_type,
                cpu=False):
    """Wrapper around gpu_gridify / cpu_gridify.

    Derives the atom count from len(atom_coords) and forwards every other
    argument unchanged to the selected backend.

    (See gpu_gridify() and cpu_gridify() for details on the arguments)

    Args:
        cpu: if True, fill the grid with cpu_gridify; otherwise launch the
            gpu_gridify kernel (default)
    """
    if cpu:
        cpu_gridify(
            grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
            batch_idx, width, res, center, rot, point_radius, point_type,
            acc_type
        )
    else:
        # number of GPU_DIM-wide blocks needed to cover `width` gridpoints
        # along one axis (ceiling division)
        dw = ((width - 1) // GPU_DIM) + 1
        gpu_gridify[(dw, dw, dw), (GPU_DIM, GPU_DIM, GPU_DIM)](
            grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
            batch_idx, width, res, center, rot, point_radius, point_type,
            acc_type
        )


def make_tensor(shape):
Expand Down Expand Up @@ -371,7 +548,7 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,

def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
conn, num_samples=32, width=24, res=1, fixed_rot=None,
point_radius=1.5, point_type=0, acc_type=0):
point_radius=1.5, point_type=0, acc_type=0, cpu=False):
"""Sample a raw batch with provided atom coordinates.
Args:
Expand All @@ -383,14 +560,22 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
num_samples: number of rotations to sample
width: grid width
res: grid resolution
rec_channels: number of receptor channels
parent_channels: number of parent chanels
fixed_rot: None or a fixed 4-element rotation vector
point_radius: atom radius in Angstroms
point_type: shape of the atom densities
acc_type: atom density accumulation type
cpu: if True, generate batches with cpu_gridify
"""
B = num_samples
T = rec_typer.size() + lig_typer.size()
N = width

torch_grid, cuda_grid = make_tensor((B,T,N,N,N))
if cpu:
t = np.zeros((B,T,N,N,N))
torch_grid = t
cuda_grid = t
else:
torch_grid, cuda_grid = make_tensor((B,T,N,N,N))

r_mask = np.zeros(len(r_types), dtype=np.uint32)
p_mask = np.zeros(len(p_types), dtype=np.uint32)
Expand Down Expand Up @@ -418,7 +603,8 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
rot=rot,
point_radius=point_radius,
point_type=point_type,
acc_type=acc_type
acc_type=acc_type,
cpu=cpu
)

mol_gridify(
Expand All @@ -433,7 +619,8 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
rot=rot,
point_radius=point_radius,
point_type=point_type,
acc_type=acc_type
acc_type=acc_type,
cpu=cpu
)

return torch_grid
38 changes: 6 additions & 32 deletions leadopt/model_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import tqdm
import numpy as np

try:
import wandb
except:
pass

from leadopt.models.voxel import VoxelFingerprintNet
from leadopt.data_util import FragmentDataset, SharedFragmentDataset, FingerprintDataset, LIG_TYPER,\
REC_TYPER
from leadopt.grid_util import get_batch
from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\
average_support, inside_support

from config import partitions, moad_partitions
from config import moad_partitions


def get_bios(p):
Expand Down Expand Up @@ -309,36 +313,6 @@ def init_models(self):

def load_data(self):
print('[*] Loading data...', flush=True)
# train_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.TRAIN if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.TRAIN)),
# filter_smi=set(moad_partitions.TRAIN_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )

# val_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.VAL if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.VAL)),
# filter_smi=set(moad_partitions.VAL_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )

dat = FragmentDataset(
self._args['fragments'],
rec_typer=REC_TYPER[self._args['rec_typer']],
Expand Down
Loading

0 comments on commit 44579ec

Please sign in to comment.