Skip to content

Commit

Permalink
made consistent seqsep for buildDistMatrix and iterNeighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesmkrieger committed Nov 5, 2023
1 parent f411813 commit 7c89ef4
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 24 deletions.
39 changes: 24 additions & 15 deletions prody/measure/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from prody.atomic import Atomic, Atom, AtomGroup, AtomSubset, Selection
from prody.kdtree import KDTree
from prody.utilities import rangeString
from prody.utilities import rangeString, LOGGER

__all__ = ['Contacts', 'iterNeighbors', 'findNeighbors']

Expand Down Expand Up @@ -131,7 +131,7 @@ def getUnitcell(self):
return self._unitcell.copy()


def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqsep=None):
"""Yield pairs of *atoms* that are within *radius* of each other and the
distance between them.
Expand All @@ -145,8 +145,8 @@ def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
and *unitcell* is not provided, unitcell information from frame will be
used if available.
If *seqdist* is provided, neighbors will be filtered out if the
sequence distance is greater than *seqdist*.
If *seqsep* is provided, neighbors will be kept if the sequence separation >= *seqsep*.
Note that *seqsep* will be ignored if atoms are not provided.
"""

radius = float(radius)
Expand Down Expand Up @@ -197,6 +197,9 @@ def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
if coords.ndim == 1:
coords = array([coords])

warned = False
SEQDIST_COORDS_WARNING = 'seqsep is ignored when not using Atomic objects'

if atoms2 is None:
if len(coords) <= 1:
raise ValueError('atoms must be more than 1')
Expand All @@ -206,8 +209,10 @@ def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
_dict = {}
if ag is None:
for (i, j), r in zip(*kdtree(radius)):
if seqdist is None or abs(i.getResnum() - j.getResnum()) > seqdist:
yield (i, j, r)
if seqsep is not None and not warned:
LOGGER.warn(SEQDIST_COORDS_WARNING)
warned = True
yield (i, j, r)
else:
for (i, j), r in zip(*kdtree(radius)):
a1 = _dict.get(i)
Expand All @@ -218,7 +223,7 @@ def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
if a2 is None:
a2 = Atom(ag, index(j), acsi)
_dict[j] = a2
if seqdist is None or abs(a1.getResnum() - a2.getResnum()) > seqdist:
if seqsep is None or abs(a1.getResnum() - a2.getResnum()) >= seqsep:
yield (a1, a2, r)
else:
try:
Expand Down Expand Up @@ -260,38 +265,42 @@ def iterNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
if ag is None or ag2 is None:
for j, xyz in enumerate(coords2):
for i, r in zip(*kdtree(radius, xyz)):
if seqdist is None or abs(i.getResnum() - j.getResnum()) > seqdist:
yield (i, j, r)
if seqsep is not None and not warned:
LOGGER.warn(SEQDIST_COORDS_WARNING)
warned = True
yield (i, j, r)
else:
for a2 in atoms2.iterAtoms():
for i, r in zip(*kdtree(radius, a2._getCoords())):
a1 = _dict.get(i)
if a1 is None:
a1 = Atom(ag, index(i), acsi)
_dict[i] = a1
if seqdist is None or abs(a1.getResnum() - a2.getResnum()) > seqdist:
if seqsep is None or abs(a1.getResnum() - a2.getResnum()) < seqsep:
yield (a1, a2, r)
else:
kdtree = KDTree(coords2, unitcell=unitcell, none=list)
_dict = {}
if ag is None or ag2 is None:
for i, xyz in enumerate(coords):
for j, r in zip(*kdtree(radius, xyz)):
if seqdist is None or abs(i.getResnum() - j.getResnum()) > seqdist:
yield (i, j, r)
if seqsep is not None and not warned:
LOGGER.warn(SEQDIST_COORDS_WARNING)
warned = True
yield (i, j, r)
else:
for a1 in atoms.iterAtoms():
for i, r in zip(*kdtree(radius, a1._getCoords())):
a2 = _dict.get(i)
if a2 is None:
a2 = Atom(ag2, index2(i), acsi2)
_dict[i] = a2
if seqdist is None or abs(a1.getResnum() - a2.getResnum()) > seqdist:
if seqsep is None or abs(a1.getResnum() - a2.getResnum()) < seqsep:
yield (a1, a2, r)


def findNeighbors(atoms, radius, atoms2=None, unitcell=None, seqdist=None):
def findNeighbors(atoms, radius, atoms2=None, unitcell=None, seqsep=None):
"""Returns list of neighbors that are within *radius* of each other and the
distance between them. See :func:`iterNeighbors` for more details."""

return list(iterNeighbors(atoms, radius, atoms2, unitcell, seqdist))
return list(iterNeighbors(atoms, radius, atoms2, unitcell, seqsep))
27 changes: 18 additions & 9 deletions prody/measure/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DISTMAT_FORMATS = set(['mat', 'rcd', 'arr'])


def buildDistMatrix(atoms1, atoms2=None, unitcell=None, format='mat'):
def buildDistMatrix(atoms1, atoms2=None, unitcell=None, format='mat', seqsep=None):
"""Returns distance matrix. When *atoms2* is given, a distance matrix
with shape ``(len(atoms1), len(atoms2))`` is built. When *atoms2* is
**None**, a symmetric matrix with shape ``(len(atoms1), len(atoms1))``
Expand All @@ -50,7 +50,16 @@ def buildDistMatrix(atoms1, atoms2=None, unitcell=None, format='mat'):
:arg format: format of the resulting array, one of ``'mat'`` (matrix,
default), ``'rcd'`` (arrays of row indices, column indices, and
distances), or ``'arr'`` (only array of distances)
:type format: bool"""
:type format: bool
:arg seqsep: if provided, distances will only be measured between atoms
with resnum differences that are greater than or equal to seqsep.
:type seqsep: int
"""

spacing = 1
if seqsep is not None:
spacing = seqsep - 1

if not isinstance(atoms1, ndarray):
try:
Expand Down Expand Up @@ -86,10 +95,10 @@ def buildDistMatrix(atoms1, atoms2=None, unitcell=None, format='mat'):
raise ValueError('format must be one of mat, rcd, or arr')
if format == 'mat':
for i, xyz in enumerate(atoms1[:-1]):
dist[i, i+1:] = dist[i+1:, i] = getDistance(xyz, atoms2[i+1:],
dist[i, i+spacing:] = dist[i+spacing:, i] = getDistance(xyz, atoms2[i+spacing:],
unitcell)
else:
dist = concatenate([getDistance(xyz, atoms2[i+1:], unitcell)
dist = concatenate([getDistance(xyz, atoms2[i+spacing:], unitcell)
for i, xyz in enumerate(atoms1)])
if format == 'rcd':
n_atoms = len(atoms1)
Expand Down Expand Up @@ -720,7 +729,7 @@ def calcADPAxes(atoms, **kwargs):
# Make sure the direction that correlates with the previous atom
# is selected
vals = vals * sign((vecs * axes[(i-1)*3:(i)*3, :]).sum(0))
axes[i*3:(i+1)*3, :] = vals * vecs
axes[i*3:(i+spacing)*3, :] = vals * vecs
# Resort the columns before returning array
axes = axes[:, [2, 1, 0]]
torf = None
Expand Down Expand Up @@ -802,7 +811,7 @@ def buildADPMatrix(atoms):
element[0, 1] = element[1, 0] = anisou[3]
element[0, 2] = element[2, 0] = anisou[4]
element[1, 2] = element[2, 1] = anisou[5]
adp[i*3:(i+1)*3, i*3:(i+1)*3] = element
adp[i*3:(i+spacing)*3, i*3:(i+spacing)*3] = element
return adp


Expand Down Expand Up @@ -860,7 +869,7 @@ def calcDistanceMatrix(coords, cutoff=None):
r += 1

for i in range(n_atoms):
for j in range(i+1, n_atoms):
for j in range(i+spacing, n_atoms):
if dist_mat[i, j] == 0.:
dist_mat[i, j] = dist_mat[j, i] = max(dists)

Expand Down Expand Up @@ -1015,7 +1024,7 @@ def assignBlocks(atoms, res_per_block=None, secstr=False, **kwargs):
block = where(blocks == i)[0]
if len(block) < shortest_block:
block_im1 = where(blocks == i-1)[0]
block_ip1 = where(blocks == i+1)[0]
block_ip1 = where(blocks == i+spacing)[0]

dist_back = calcDistance(atoms[block_im1][-1], atoms[block][0])
dist_fwd = calcDistance(atoms[block][-1], atoms[block_ip1][0])
Expand All @@ -1025,7 +1034,7 @@ def assignBlocks(atoms, res_per_block=None, secstr=False, **kwargs):
blocks[where(blocks == i)[0]] = i-1
elif dist_fwd < min_dist_cutoff:
# join onto next block
blocks[where(blocks == i)[0]] = i+1
blocks[where(blocks == i)[0]] = i+spacing

blocks, amap = extendAtomicData(blocks, sel_ca, atoms)

Expand Down
9 changes: 9 additions & 0 deletions prody/tests/measure/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
UBI_CONTACTS = Contacts(UBI_XYZ)
UBI_CONTACTS_PBC = Contacts(UBI_XYZ, UBI_UC)

UBI_ADDH = parseDatafile('1ubi_addH')


class TestContacts(unittest.TestCase):

Expand Down Expand Up @@ -92,6 +94,13 @@ def testAtomicArgumentSwitching(self):
neighbors2.sort()
self.assertEqual(neighbors1, neighbors2)

def testSeqsep(self):

neighbors = findNeighbors(UCA, UCA_RADIUS, seqsep=3)
n_neighbors = (buildDistMatrix(UCA, seqsep=3,
format='arr') <= UCA_RADIUS).sum()
assert_equal(len(neighbors), n_neighbors)

def testPBCCoordArgumentSwitching(self):

dist = 12.
Expand Down

0 comments on commit 7c89ef4

Please sign in to comment.