Skip to content

Commit

Permalink
Merge pull request prody#1849 from jamesmkrieger/logistic
Browse files Browse the repository at this point in the history
Logistic regression for differences between sub-ensembles
  • Loading branch information
jamesmkrieger authored Apr 16, 2024
2 parents e8d68bc + 0c6b105 commit 29610e7
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 41 deletions.
68 changes: 36 additions & 32 deletions prody/database/bioexcel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from prody.atomic.atomgroup import AtomGroup
from prody.atomic.functions import extendAtomicData
from prody.proteins.pdbfile import parsePDB, _parsePDBLines
from prody.proteins.pdbfile import parsePDB
from prody.trajectory.psffile import parsePSF, writePSF
from prody.trajectory.dcdfile import parseDCD

Expand Down Expand Up @@ -58,14 +58,7 @@ def fetchBioexcelPDB(acc, **kwargs):
if selection is not None:
url += '?selection=' + selection.replace(" ","%20")

response = requestFromUrl(url, timeout, source='pdb')

if PY3K:
response = response.decode()

fo = open(filepath, 'w')
fo.write(response)
fo.close()
filepath = requestFromUrl(url, timeout, filepath, source='pdb')

return filepath

Expand Down Expand Up @@ -118,11 +111,7 @@ def fetchBioexcelTrajectory(acc, **kwargs):
if selection is not None:
url += '&selection=' + selection.replace(" ","%20")

response = requestFromUrl(url, timeout, source='xtc')

fo = open(filepath, 'wb')
fo.write(response)
fo.close()
filepath = requestFromUrl(url, timeout, filepath, source='xtc')

if convert:
filepath = convertXtcToDcd(filepath, **kwargs)
Expand Down Expand Up @@ -162,14 +151,7 @@ def fetchBioexcelTopology(acc, **kwargs):

if not isfile(filepath):
url = prefix + acc + "/topology"
response = requestFromUrl(url, timeout, source='json')

if PY3K:
response = response.decode()

fo = open(filepath, 'w')
fo.write(response)
fo.close()
filepath = requestFromUrl(url, timeout, filepath, source='json')

if convert:
ag = parseBioexcelTopology(filepath, **kwargs)
Expand Down Expand Up @@ -269,10 +251,16 @@ def parseBioexcelPDB(query, **kwargs):
fetching it if needed using **kwargs
"""
kwargs['convert'] = True
if not isfile(query):
filename = fetchBioexcelPDB(query, **kwargs)
else:
if isfile(query):
filename = query
elif isfile(query + '.pdb'):
filename = query + '.pdb'
else:
filename = fetchBioexcelPDB(query, **kwargs)

ag = parsePDB(filename)
if ag is None:
filename = fetchBioexcelPDB(query, **kwargs)

return parsePDB(filename)

Expand Down Expand Up @@ -301,7 +289,7 @@ def convertXtcToDcd(filepath, **kwargs):

return filepath

def requestFromUrl(url, timeout, source=None):
def requestFromUrl(url, timeout, filepath, source=None):
"""Helper function to make a request from a url and return the response"""
import requests
import json
Expand All @@ -320,15 +308,31 @@ def requestFromUrl(url, timeout, source=None):
if source == 'json':
json.loads(response)

if PY3K:
response = response.decode()

fo = open(filepath, 'w')
fo.write(response)
fo.close()

elif source == 'xtc':
ftmp = tempfile.NamedTemporaryFile()
ftmp.write(response, 'wb')
ftmp.close()
fo = open(filepath, 'wb')
fo.write(response)
fo.close()

top = mdtraj.load_psf(fetchBioexcelTopology(acc))
mdtraj.load_xtc(ftmp.name, top=top)
mdtraj.load_xtc(filepath, top=top)

elif source == 'pdb':
_parsePDBLines(response)
if PY3K:
response = response.decode()

fo = open(filepath, 'w')
fo.write(response)
fo.close()

ag = parsePDB(filepath)
numAtoms = ag.numAtoms()

except Exception:
pass
Expand All @@ -338,7 +342,7 @@ def requestFromUrl(url, timeout, source=None):
sleep = 20 if int(sleep * 1.5) >= 20 else int(sleep * 1.5)
LOGGER.sleep(int(sleep), '. Trying to reconnect...')

return response
return filepath

def checkSelection(**kwargs):
"""Helper function to check selection"""
Expand Down
4 changes: 4 additions & 0 deletions prody/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,7 @@
from . import lda
from .lda import *
__all__.extend(lda.__all__)

from . import logistic
from .logistic import *
__all__.extend(logistic.__all__)
26 changes: 20 additions & 6 deletions prody/dynamics/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .rtb import RTB
from .pca import PCA, EDA
from .lda import LDA
from .logistic import LRA
from .imanm import imANM
from .exanm import exANM
from .mode import Vector, Mode, VectorBase
Expand Down Expand Up @@ -86,8 +87,14 @@ def saveModel(nma, filename=None, matrices=False, **kwargs):
type_ = 'PCA'
elif isinstance(nma, LDA):
type_ = 'LDA'
attr_list.append('_lda')
attr_list.append('_labels')
attr_list.append('_shuffled_ldas')
elif isinstance(nma, LRA):
type_ = 'LRA'
attr_list.append('_lra')
attr_list.append('_labels')
attr_list.append('_shuffled_lras')
else:
type_ = 'NMA'

Expand Down Expand Up @@ -190,6 +197,8 @@ def loadModel(filename, **kwargs):
nma = RTB(title)
elif type_ == 'LDA':
nma = LDA(title)
elif type_ == 'LRA':
nma = LRA(title)
else:
raise IOError('NMA model type is not recognized: {0}'.format(type_))

Expand All @@ -211,6 +220,11 @@ def loadModel(filename, **kwargs):
else:
dict_[attr] = attr_dict[attr]

if '_shuffled_ldas' in nma.__dict__:
nma._shuffled_ldas = [arr[0].getModel() for arr in nma._shuffled_ldas]
elif '_shuffled_lras' in nma.__dict__:
nma._shuffled_lras = [arr[0].getModel() for arr in nma._shuffled_lras]

return nma


Expand Down Expand Up @@ -940,8 +954,8 @@ def parseGromacsModes(run_path, title="", model='nma', **kwargs):

if isfile(eigval_fname):
vals_fname = eigval_fname
elif isfile(run_path + eigval_fname):
vals_fname = run_path + eigval_fname
elif isfile(join(run_path, eigval_fname)):
vals_fname = join(run_path, eigval_fname)
else:
raise ValueError('eigval_fname should point be a path to a file '
'either relative to run_path or an absolute one')
Expand All @@ -953,8 +967,8 @@ def parseGromacsModes(run_path, title="", model='nma', **kwargs):

if isfile(eigvec_fname):
vecs_fname = eigval_fname
elif isfile(run_path + eigvec_fname):
vecs_fname = run_path + eigvec_fname
elif isfile(join(run_path, eigvec_fname)):
vecs_fname = join(run_path, eigvec_fname)
else:
raise ValueError('eigvec_fname should point be a path to a file '
'either relative to run_path or an absolute one')
Expand All @@ -966,8 +980,8 @@ def parseGromacsModes(run_path, title="", model='nma', **kwargs):

if isfile(pdb_fname):
pdb = eigval_fname
elif isfile(run_path + pdb_fname):
pdb = run_path + pdb_fname
elif isfile(join(run_path, pdb_fname)):
pdb = join(run_path, pdb_fname)
else:
raise ValueError('pdb_fname should point be a path to a file '
'either relative to run_path or an absolute one')
Expand Down
180 changes: 180 additions & 0 deletions prody/dynamics/logistic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""This module defines a class for logistic regression classification calculations."""

import time

import numpy as np

from prody import LOGGER
from prody.atomic import Atomic
from prody.ensemble import Ensemble
from prody.utilities import isListLike

from .nma import NMA

__all__ = ['LRA']


class LRA(NMA):

    """A class for logistic regression classification of conformational
    ensembles. See examples in :ref:`pca`.

    The coefficients of a fitted
    :class:`sklearn.linear_model.LogisticRegression` are normalised and
    stored as the single mode of the model, with its eigenvalue and variance
    set to 1. Optionally the fit is repeated on randomly shuffled labels to
    give a reference distribution for the coefficients."""

    # Upper bound on consecutive re-draws of a shuffle whose fitted direction
    # is (anti)parallel to the original one; it keeps the shuffle loop finite
    # for degenerate inputs where every permutation reproduces the labelling.
    _MAX_REDRAWS = 100

    def __init__(self, name='Unknown'):
        """Create an empty model called *name*."""
        NMA.__init__(self, name)
        # Always defined, so the getShuffled* accessors do not depend on
        # calcModes having been run with n_shuffles > 0.
        self._shuffled_lras = []

    def calcModes(self, coordsets, labels, lasso=True, **kwargs):
        """Calculate logistic regression classification modes between classes.

        This method uses :class:`sklearn.linear_model.LogisticRegression`
        on coordsets with class labels.

        *coordsets* argument may be one of the following:

        * :class:`.Atomic`
        * :class:`.Ensemble`
        * :class:`numpy.ndarray` with shape ``(n_csets, n_atoms, 3)``

        :arg labels: a set of labels for discriminating classes, one per
            conformer
        :type labels: :class:`~numpy.ndarray`

        :arg lasso: whether to use lasso regression (sets penalty='l1',
            solver='liblinear' unless these are given explicitly)
            Default **True**
        :type lasso: bool

        :arg n_shuffles: number of random shuffles of labels to assess
            variability. Default **0**
        :type n_shuffles: int

        :arg quiet: suppress the timing messages. Default **False**
        :type quiet: bool

        Other kwargs for the LogisticRegression class can also be used.
        The same estimator settings are used for the shuffled fits.

        :raises ImportError: if sklearn is not installed
        :raises TypeError: if *coordsets* or *labels* have an unsupported type
        :raises ValueError: if the coordinate array is malformed or *labels*
            does not provide one label per conformer
        """
        try:
            from sklearn.linear_model import LogisticRegression
        except ImportError:
            raise ImportError("Please install sklearn to use LogisticRegression")

        start = time.time()
        self._clear()
        # Results of a previous run must not survive a new calculation.
        self._shuffled_lras = []

        if isinstance(coordsets, np.ndarray):
            if (coordsets.ndim != 3 or coordsets.shape[2] != 3 or
                    coordsets.dtype not in (np.float32, float)):
                raise ValueError('coordsets is not a valid coordinate array')
            self._coordsets = coordsets
        elif isinstance(coordsets, (Atomic, Ensemble)):
            self._coordsets = coordsets._getCoordsets()
        else:
            raise TypeError('coordsets should be Atomic, Ensemble or numpy.ndarray, not {0}'
                            .format(type(coordsets)))

        nconfs = self._coordsets.shape[0]
        if not isListLike(labels):
            raise TypeError('labels must be either a list or a numpy.ndarray, not {0}'
                            .format(type(labels)))
        if not isinstance(labels, np.ndarray):
            labels = np.array(labels)
        if labels.ndim != 1 or len(labels) != nconfs:
            raise ValueError('labels should have same number as conformers')

        self._n_atoms = self._coordsets.shape[1]

        # One row per conformer, one column per Cartesian coordinate.
        self._coordsets = self._coordsets.reshape(nconfs, -1)
        self._labels = labels

        quiet = kwargs.pop('quiet', False)

        self._n_shuffles = kwargs.pop('n_shuffles', 0)

        if lasso:
            if 'penalty' not in kwargs:
                kwargs['penalty'] = 'l1'
            else:
                LOGGER.warn('using provided penalty kwarg instead of l1 from lasso')

            if 'solver' not in kwargs:
                kwargs['solver'] = 'liblinear'
            else:
                LOGGER.warn('using provided solver kwarg instead of liblinear from lasso')

        # Snapshot of the final estimator settings, reused for shuffled fits.
        estimator_kwargs = dict(kwargs)

        self._lra = LogisticRegression(**kwargs)
        # fit() returns the fitted estimator, so _projection refers to the
        # same object as _lra.
        self._projection = self._lra.fit(self._coordsets, self._labels)
        # NOTE(review): if every coefficient is zero (possible with strong l1
        # regularisation) the norm is 0 and this produces NaNs - confirm the
        # desired handling.
        self._array = self._lra.coef_.T/np.linalg.norm(self._lra.coef_)
        self._eigvals = np.ones(1)
        self._vars = np.ones(1)

        self._n_modes = 1

        if not quiet:
            if self._n_modes > 1:
                LOGGER.debug('{0} modes were calculated in {1:.2f}s.'
                             .format(self._n_modes, time.time()-start))
            else:
                LOGGER.debug('{0} mode was calculated in {1:.2f}s.'
                             .format(self._n_modes, time.time()-start))

        if self._n_shuffles > 0:
            if self._n_modes > 1:
                LOGGER.debug('Calculating {0} modes for {1} shuffles.'
                             .format(self._n_modes, self._n_shuffles))
            else:
                LOGGER.debug('Calculating {0} mode for {1} shuffles.'
                             .format(self._n_modes, self._n_shuffles))

            self._shuffled_lras = [LRA('shuffle '+str(n))
                                   for n in range(self._n_shuffles)]
            self._coordsets_reshaped = self._coordsets.reshape(
                self._coordsets.shape[0], self._n_atoms, -1)

            # A single unseeded generator: fresh, unpredictable entropy is
            # pulled from the OS once instead of on every iteration.
            rng = np.random.default_rng()
            # The whole coefficient array has unit norm, so the dot product
            # of two flattened arrays is the cosine between fitted directions.
            reference = self._array.ravel()

            n = 0
            redraws = 0
            while n < self._n_shuffles:
                labelsNew = self._labels.copy()
                rng.shuffle(labelsNew)  # in place

                shuffled = self._shuffled_lras[n]
                # lasso=False because estimator_kwargs already carries the
                # penalty/solver chosen above; this also avoids repeating
                # the lasso warnings for each shuffle.
                shuffled.calcModes(self._coordsets_reshaped, labelsNew,
                                   lasso=False, quiet=True,
                                   **estimator_kwargs)

                overlap = abs(np.dot(shuffled._array.ravel(), reference))
                if np.allclose(overlap, 1):
                    # LRA gave the same (or exactly flipped) direction, as
                    # the labels match or are exactly flipped: draw again.
                    if redraws < self._MAX_REDRAWS:
                        redraws += 1
                        continue
                    LOGGER.warn('shuffle {0} still matches the original labels '
                                'after {1} re-draws; keeping it'
                                .format(n, redraws))

                redraws = 0
                n += 1

        if self._n_shuffles > 0 and not quiet:
            if self._n_modes > 1:
                LOGGER.debug('{0} modes were calculated with {1} shuffles in {2:.2f}s.'
                             .format(self._n_modes, self._n_shuffles, time.time()-start))
            else:
                LOGGER.debug('{0} mode was calculated with {1} shuffles in {2:.2f}s.'
                             .format(self._n_modes, self._n_shuffles, time.time()-start))

    def addEigenpair(self, eigenvector, eigenvalue=None):
        """Add eigen *vector* and eigen *value* pair(s) to the instance.
        If eigen *value* is omitted, it will be set to 1. Eigenvalues
        are set as variances."""

        NMA.addEigenpair(self, eigenvector, eigenvalue)
        self._vars = self._eigvals

    def setEigens(self, vectors, values=None):
        """Set eigen *vectors* and eigen *values*. If eigen *values* are
        omitted, they will be set to 1. Eigenvalues are set as variances."""

        self._clear()
        NMA.setEigens(self, vectors, values)
        self._vars = self._eigvals

    def getShuffledModes(self):
        """Return a new list holding the models fitted on shuffled labels.
        The list is empty when no shuffles have been calculated."""

        return list(self._shuffled_lras)

    def getShuffledEigvecs(self):
        """Return an array stacking the eigenvectors of every shuffled
        model, in the order the shuffles were calculated."""

        return np.array([lra.getEigvecs() for lra in self._shuffled_lras])

    def getShuffledPercentile(self, percentile, take_abs=True):
        """Return the *percentile* of the shuffled eigenvector elements.

        :arg percentile: percentile to compute, passed to
            :func:`numpy.percentile`
        :type percentile: float

        :arg take_abs: whether to take absolute values of the elements
            before computing the percentile. Default **True**
        :type take_abs: bool

        :raises ValueError: if no shuffled models are available
        """
        shuffles = self.getShuffledEigvecs()
        if shuffles.size == 0:
            raise ValueError('no shuffled models are available; '
                             'use calcModes with n_shuffles > 0 first')
        if take_abs:
            shuffles = abs(shuffles)
        return np.percentile(shuffles, percentile)
Loading

0 comments on commit 29610e7

Please sign in to comment.