From 4ceef5926632acd7d96ea25a28d8d2b789491902 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Mon, 22 Jan 2018 19:17:10 +0100 Subject: [PATCH] [codeclim] Refactor to separate ContactCount --- contact_map/__init__.py | 4 +- contact_map/contact_count.py | 211 ++++++++++++++++++++++++ contact_map/contact_map.py | 209 +---------------------- contact_map/tests/test_contact_count.py | 141 ++++++++++++++++ contact_map/tests/test_contact_map.py | 137 +-------------- docs/api.rst | 4 +- 6 files changed, 362 insertions(+), 344 deletions(-) create mode 100644 contact_map/contact_count.py create mode 100644 contact_map/tests/test_contact_count.py diff --git a/contact_map/__init__.py b/contact_map/__init__.py index 76c6042..96a24f9 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -6,9 +6,11 @@ __version__ = version.version from .contact_map import ( - ContactMap, ContactFrequency, ContactDifference, ContactCount + ContactMap, ContactFrequency, ContactDifference ) +from .contact_count import ContactCount + from .min_dist import NearestAtoms, MinimumDistanceCounter from .dask_runner import DaskContactFrequency diff --git a/contact_map/contact_count.py b/contact_map/contact_count.py new file mode 100644 index 0000000..d0876cf --- /dev/null +++ b/contact_map/contact_count.py @@ -0,0 +1,211 @@ +import scipy +import numpy as np +import pandas as pd + +from .plot_utils import ranged_colorbar + +# matplotlib is technically optional, but required for plotting +try: + import matplotlib + import matplotlib.pyplot as plt +except ImportError: + HAS_MATPLOTLIB = False +else: + HAS_MATPLOTLIB = True + +def _colorbar(with_colorbar, cmap_f, norm, min_val): + if with_colorbar is False: + return None + elif with_colorbar is True: + cbmin = np.floor(min_val) # [-1.0..0.0] => -1; [0.0..1.0] => 0 + cbmax = 1.0 + cb = ranged_colorbar(cmap_f, norm, cbmin, cbmax) + # leave open other inputs to be parsed later (like tuples) + return cb + +class ContactCount(object): + 
"""Return object when dealing with contacts (residue or atom). + + This contains all the information about the contacts of a given type. + This information can be represented several ways. One is as a list of + contact pairs, each associated with the fraction of time the contact + occurs. Another is as a matrix, where the rows and columns label the + pair number, and the value is the fraction of time. This class provides + several methods to get different representations of this data for + further analysis. + + In general, instances of this class shouldn't be created by a user using + ``__init__``; instead, they will be returned by other methods. So users + will often need to use this object for analysis. + + Parameters + ---------- + counter : :class:`collections.Counter` + the counter describing the count of how often the contact occurred; + key is a frozenset of a pair of numbers (identifying the + atoms/residues); value is the raw count of the number of times it + occurred + object_f : callable + method to obtain the object associated with the number used in + ``counter``; typically :meth:`mdtraj.Topology.residue` or + :meth:`mdtraj.Topology.atom`. 
+ n_x : int + number of objects in the x direction (used in plotting) + n_y : int + number of objects in the y direction (used in plotting) + """ + def __init__(self, counter, object_f, n_x, n_y): + self._counter = counter + self._object_f = object_f + self.n_x = n_x + self.n_y = n_y + + @property + def counter(self): + """ + :class:`collections.Counter` : + keys use index number; count is contact occurrences + """ + return self._counter + + @property + def sparse_matrix(self): + """ + :class:`scipy.sparse.dok.dok_matrix` : + sparse matrix representation of contacts + + Rows/columns correspond to indices and the values correspond to + the count + """ + mtx = scipy.sparse.dok_matrix((self.n_x, self.n_y)) + for (k, v) in self._counter.items(): + key = list(k) + mtx[key[0], key[1]] = v + mtx[key[1], key[0]] = v + return mtx + + @property + def df(self): + """ + :class:`pandas.SparseDataFrame` : + DataFrame representation of the contact matrix + + Rows/columns correspond to indices and the values correspond to + the count + """ + mtx = self.sparse_matrix.tocoo() + index = list(range(self.n_x)) + columns = list(range(self.n_y)) + return pd.SparseDataFrame(mtx, index=index, columns=columns) + + def plot(self, cmap='seismic', vmin=-1.0, vmax=1.0, with_colorbar=True): + """ + Plot contact matrix (requires matplotlib) + + Parameters + ---------- + cmap : str + color map name, default 'seismic' + vmin : float + minimum value for color map interpolation; default -1.0 + vmax : float + maximum value for color map interpolation; default 1.0 + + Returns + ------- + fig : :class:`matplotlib.Figure` + matplotlib figure object for this plot + ax : :class:`matplotlib.Axes` + matplotlib axes object for this plot + """ + if not HAS_MATPLOTLIB: # pragma: no cover + raise RuntimeError("Error importing matplotlib") + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + cmap_f = plt.get_cmap(cmap) + + fig, ax = plt.subplots() + ax.axis([0, self.n_x, 0, self.n_y]) + 
ax.set_facecolor(cmap_f(norm(0.0))) + + min_val = 0.0 + + for (pair, value) in self.counter.items(): + if value < min_val: + min_val = value + pair_list = list(pair) + patch_0 = matplotlib.patches.Rectangle( + pair_list, 1, 1, + facecolor=cmap_f(norm(value)), + linewidth=0 + ) + patch_1 = matplotlib.patches.Rectangle( + (pair_list[1], pair_list[0]), 1, 1, + facecolor=cmap_f(norm(value)), + linewidth=0 + ) + ax.add_patch(patch_0) + ax.add_patch(patch_1) + + _colorbar(with_colorbar, cmap_f, norm, min_val) + + return (fig, ax) + + def most_common(self, obj=None): + """ + Most common values (ordered) with object as keys. + + This uses the objects for the contact pair (typically MDTraj + ``Atom`` or ``Residue`` objects), instead of numeric indices. This + is more readable and can be easily used for further manipulation. + + Parameters + ---------- + obj : MDTraj Atom or Residue + if given, the return value only has entries including this + object (allowing one to, for example, get the most common + contacts with a specific residue) + + Returns + ------- + list : + the most common contacts in order. If the list is ``l``, then + each element ``l[e]`` is a tuple with two parts: ``l[e][0]`` is + the key, which is a pair of Atom or Residue objects, and + ``l[e][1]`` is the count of how often that contact occurred. + + See also + -------- + most_common_idx : same thing, using index numbers as key + """ + if obj is None: + result = [ + ([self._object_f(idx) for idx in common[0]], common[1]) + for common in self.most_common_idx() + ] + else: + obj_idx = obj.index + result = [ + ([self._object_f(idx) for idx in common[0]], common[1]) + for common in self.most_common_idx() + if obj_idx in common[0] + ] + return result + + def most_common_idx(self): + """ + Most common values (ordered) with indices as keys. + + Returns + ------- + list : + the most common contacts in order. 
If the list is ``l``, + then each element ``l[e]`` consists of two parts: ``l[e][0]`` is + a pair of integers, representing the indices of the objects + associated with the contact, and ``l[e][1]`` is the count of how + often that contact occurred + + See also + -------- + most_common : same thing, using objects as key + """ + return self._counter.most_common() diff --git a/contact_map/contact_map.py b/contact_map/contact_map.py index 247fcfc..005d938 100644 --- a/contact_map/contact_map.py +++ b/contact_map/contact_map.py @@ -7,23 +7,14 @@ import itertools import pickle import json -import scipy import numpy as np import pandas as pd import mdtraj as md +from .contact_count import ContactCount from .plot_utils import ranged_colorbar from .py_2_3 import inspect_method_arguments -# matplotlib is technically optional, but required for plotting -try: - import matplotlib - import matplotlib.pyplot as plt -except ImportError: - HAS_MATPLOTLIB = False -else: - HAS_MATPLOTLIB = True - # TODO: # * switch to something where you can define the haystack -- the trick is to # replace the current mdtraj._compute_neighbors with something that @@ -61,204 +52,6 @@ def _residue_and_index(residue, topology): res = topology.residue(res_idx) return (res, res_idx) -def _colorbar(with_colorbar, cmap_f, norm, min_val): - if with_colorbar is False: - return None - elif with_colorbar is True: - cbmin = np.floor(min_val) # [-1.0..0.0] => -1; [0.0..1.0] => 0 - cbmax = 1.0 - cb = ranged_colorbar(cmap_f, norm, cbmin, cbmax) - # leave open other inputs to be parsed later (like tuples) - return cb - - -class ContactCount(object): - """Return object when dealing with contacts (residue or atom). - - This contains all the information about the contacts of a given type. - This information can be represented several ways. One is as a list of - contact pairs, each associated with the fraction of time the contact - occurs. 
Another is as a matrix, where the rows and columns label the - pair number, and the value is the fraction of time. This class provides - several methods to get different representations of this data for - further analysis. - - In general, instances of this class shouldn't be created by a user using - ``__init__``; instead, they will be returned by other methods. So users - will often need to use this object for analysis. - - Parameters - ---------- - counter : :class:`collections.Counter` - the counter describing the count of how often the contact occurred; - key is a frozenset of a pair of numbers (identifying the - atoms/residues); value is the raw count of the number of times it - occurred - object_f : callable - method to obtain the object associated with the number used in - ``counter``; typically :meth:`mdtraj.Topology.residue` or - :meth:`mdtraj.Topology.atom`. - n_x : int - number of objects in the x direction (used in plotting) - n_y : int - number of objects in the y direction (used in plotting) - """ - def __init__(self, counter, object_f, n_x, n_y): - self._counter = counter - self._object_f = object_f - self.n_x = n_x - self.n_y = n_y - - @property - def counter(self): - """ - :class:`collections.Counter` : - keys use index number; count is contact occurrences - """ - return self._counter - - @property - def sparse_matrix(self): - """ - :class:`scipy.sparse.dok.dok_matrix` : - sparse matrix representation of contacts - - Rows/columns correspond to indices and the values correspond to - the count - """ - mtx = scipy.sparse.dok_matrix((self.n_x, self.n_y)) - for (k, v) in self._counter.items(): - key = list(k) - mtx[key[0], key[1]] = v - mtx[key[1], key[0]] = v - return mtx - - @property - def df(self): - """ - :class:`pandas.SparseDataFrame` : - DataFrame representation of the contact matrix - - Rows/columns correspond to indices and the values correspond to - the count - """ - mtx = self.sparse_matrix.tocoo() - index = list(range(self.n_x)) - columns = 
list(range(self.n_y)) - return pd.SparseDataFrame(mtx, index=index, columns=columns) - - def plot(self, cmap='seismic', vmin=-1.0, vmax=1.0, with_colorbar=True): - """ - Plot contact matrix (requires matplotlib) - - Parameters - ---------- - cmap : str - color map name, default 'seismic' - vmin : float - minimum value for color map interpolation; default -1.0 - vmax : float - maximum value for color map interpolation; default 1.0 - - Returns - ------- - fig : :class:`matplotlib.Figure` - matplotlib figure object for this plot - ax : :class:`matplotlib.Axes` - matplotlib axes object for this plot - """ - if not HAS_MATPLOTLIB: # pragma: no cover - raise RuntimeError("Error importing matplotlib") - norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) - cmap_f = plt.get_cmap(cmap) - - fig, ax = plt.subplots() - ax.axis([0, self.n_x, 0, self.n_y]) - ax.set_facecolor(cmap_f(norm(0.0))) - - min_val = 0.0 - - for (pair, value) in self.counter.items(): - if value < min_val: - min_val = value - pair_list = list(pair) - patch_0 = matplotlib.patches.Rectangle( - pair_list, 1, 1, - facecolor=cmap_f(norm(value)), - linewidth=0 - ) - patch_1 = matplotlib.patches.Rectangle( - (pair_list[1], pair_list[0]), 1, 1, - facecolor=cmap_f(norm(value)), - linewidth=0 - ) - ax.add_patch(patch_0) - ax.add_patch(patch_1) - - _colorbar(with_colorbar, cmap_f, norm, min_val) - - return (fig, ax) - - def most_common(self, obj=None): - """ - Most common values (ordered) with object as keys. - - This uses the objects for the contact pair (typically MDTraj - ``Atom`` or ``Residue`` objects), instead of numeric indices. This - is more readable and can be easily used for further manipulation. - - Parameters - ---------- - obj : MDTraj Atom or Residue - if given, the return value only has entries including this - object (allowing one to, for example, get the most common - contacts with a specific residue) - - Returns - ------- - list : - the most common contacts in order. 
If the list is ``l``, then - each element ``l[e]`` is a tuple with two parts: ``l[e][0]`` is - the key, which is a pair of Atom or Residue objects, and - ``l[e][1]`` is the count of how often that contact occurred. - - See also - -------- - most_common_idx : same thing, using index numbers as key - """ - if obj is None: - result = [ - ([self._object_f(idx) for idx in common[0]], common[1]) - for common in self.most_common_idx() - ] - else: - obj_idx = obj.index - result = [ - ([self._object_f(idx) for idx in common[0]], common[1]) - for common in self.most_common_idx() - if obj_idx in common[0] - ] - return result - - def most_common_idx(self): - """ - Most common values (ordered) with indices as keys. - - Returns - ------- - list : - the most common contacts in order. The if the list is ``l``, - then each element ``l[e]`` consists of two parts: ``l[e][0]`` is - a pair of integers, representing the indices of the objects - associated with the contact, and ``l[e][1]`` is the count of how - often that contact occurred - - See also - -------- - most_common : same thing, using objects as key - """ - return self._counter.most_common() - class ContactObject(object): """ diff --git a/contact_map/tests/test_contact_count.py b/contact_map/tests/test_contact_count.py new file mode 100644 index 0000000..12d75cb --- /dev/null +++ b/contact_map/tests/test_contact_count.py @@ -0,0 +1,141 @@ +import numpy as np + +# pylint: disable=wildcard-import, missing-docstring, protected-access +# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use +# pylint: disable=wrong-import-order, unused-wildcard-import + +# includes pytest +from .utils import * +from contact_map.contact_map import ContactFrequency +from .test_contact_map import (traj, traj_atom_contact_count, + traj_residue_contact_count, + check_most_common_order) + +from contact_map.contact_count import * + +class TestContactCount(object): + def setup(self): + self.map = ContactFrequency(traj, cutoff=0.075, + 
n_neighbors_ignored=0) + self.topology = self.map.topology + self.atom_contacts = self.map.atom_contacts + self.residue_contacts = self.map.residue_contacts + + self.atom_matrix = np.array([ + # 0 1 2 3 4 5 6 7 8 9 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.2], # 0 + [0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 0.2, 0.2], # 1 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 2 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 3 + [0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 1.0, 0.4, 0.2, 0.0], # 4 + [0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.4, 0.2, 0.0], # 5 + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], # 6 + [0.0, 0.0, 0.0, 0.0, 0.4, 0.4, 0.0, 0.0, 0.0, 0.0], # 7 + [0.2, 0.2, 0.0, 0.0, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0], # 8 + [0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 9 + ]) + self.residue_matrix = np.array([ + # 0 1 2 3 4 + [0.0, 0.0, 1.0, 0.0, 0.2], # 0 + [0.0, 0.0, 0.0, 0.0, 0.0], # 1 + [1.0, 0.0, 0.0, 1.0, 0.2], # 2 + [0.0, 0.0, 1.0, 0.0, 0.0], # 3 + [0.2, 0.0, 0.2, 0.0, 0.0] # 4 + ]) + + # HAS_MATPLOTLIB imported by contact_count wildcard + @pytest.mark.skipif(not HAS_MATPLOTLIB, reason="Missing matplotlib") + def test_plot(self): + # purely smoke tests + self.residue_contacts.plot() + self.atom_contacts.plot() + self.residue_contacts.plot(with_colorbar=False) + + def test_initialization(self): + assert self.atom_contacts._object_f == self.topology.atom + assert self.atom_contacts.n_x == self.topology.n_atoms + assert self.atom_contacts.n_y == self.topology.n_atoms + assert self.residue_contacts._object_f == self.topology.residue + assert self.residue_contacts.n_x == self.topology.n_residues + assert self.residue_contacts.n_y == self.topology.n_residues + + def test_sparse_matrix(self): + assert_array_equal(self.map.atom_contacts.sparse_matrix.todense(), + self.atom_matrix) + assert_array_equal(self.map.residue_contacts.sparse_matrix.todense(), + self.residue_matrix) + + def test_df(self): + atom_df = self.map.atom_contacts.df + residue_df = 
self.map.residue_contacts.df + assert isinstance(atom_df, pd.SparseDataFrame) + assert isinstance(residue_df, pd.SparseDataFrame) + + assert_array_equal(atom_df.to_dense().as_matrix(), + zero_to_nan(self.atom_matrix)) + assert_array_equal(residue_df.to_dense().as_matrix(), + zero_to_nan(self.residue_matrix)) + + @pytest.mark.parametrize("obj_type", ['atom', 'res']) + def test_most_common(self, obj_type): + if obj_type == 'atom': + source_expected = traj_atom_contact_count + contacts = self.map.atom_contacts + obj_func = self.topology.atom + elif obj_type == 'res': + source_expected = traj_residue_contact_count + contacts = self.map.residue_contacts + obj_func = self.topology.residue + else: + raise RuntimeError("This shouldn't happen") + + expected = [ + (frozenset([obj_func(idx) for idx in ll[0]]), float(ll[1]) / 5.0) + for ll in source_expected.items() + ] + + most_common = contacts.most_common() + cleaned = [(frozenset(ll[0]), ll[1]) for ll in most_common] + + check_most_common_order(most_common) + assert set(cleaned) == set(expected) + + @pytest.mark.parametrize("obj_type", ['atom', 'res']) + def test_most_common_with_object(self, obj_type): + top = self.topology + if obj_type == 'atom': + contacts = self.map.atom_contacts + obj = top.atom(4) + expected = [(frozenset([obj, top.atom(6)]), 1.0), + (frozenset([obj, top.atom(1)]), 0.8), + (frozenset([obj, top.atom(7)]), 0.4), + (frozenset([obj, top.atom(8)]), 0.2)] + elif obj_type == 'res': + contacts = self.map.residue_contacts + obj = self.topology.residue(2) + expected = [(frozenset([obj, top.residue(0)]), 1.0), + (frozenset([obj, top.residue(3)]), 1.0), + (frozenset([obj, top.residue(4)]), 0.2)] + else: + raise RuntimeError("This shouldn't happen") + + most_common = contacts.most_common(obj) + cleaned = [(frozenset(ll[0]), ll[1]) for ll in most_common] + + check_most_common_order(most_common) + assert set(cleaned) == set(expected) + + @pytest.mark.parametrize("obj_type", ['atom', 'res']) + def 
test_most_common_idx(self, obj_type): + if obj_type == 'atom': + source_expected = traj_atom_contact_count + contacts = self.map.atom_contacts + elif obj_type == 'res': + source_expected = traj_residue_contact_count + contacts = self.map.residue_contacts + else: + raise RuntimeError("This shouldn't happen") + + expected_count = [(ll[0], float(ll[1]) / 5.0) + for ll in source_expected.items()] + assert set(contacts.most_common_idx()) == set(expected_count) diff --git a/contact_map/tests/test_contact_map.py b/contact_map/tests/test_contact_map.py index 013f010..99af6cd 100644 --- a/contact_map/tests/test_contact_map.py +++ b/contact_map/tests/test_contact_map.py @@ -1,8 +1,7 @@ import os import collections -import numpy as np -import mdtraj as md import json +import mdtraj as md # pylint: disable=wildcard-import, missing-docstring, protected-access # pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use @@ -13,6 +12,7 @@ # stuff to be testing in this file from contact_map.contact_map import * +from contact_map.contact_count import ContactCount, HAS_MATPLOTLIB traj = md.load(find_testfile("trajectory.pdb")) @@ -169,8 +169,8 @@ def test_dict_serialization_cycle(self, idx): def test_json_serialization_cycle(self, idx): m = self.maps[idx] - json = m.to_json() - m2 = ContactMap.from_json(json) + json_str = m.to_json() + m2 = ContactMap.from_json(json_str) _contact_object_compare(m, m2) assert m == m2 @@ -436,134 +436,6 @@ def test_subtract_contact_frequency(self): last_frame.residue_contacts.counter -class TestContactCount(object): - def setup(self): - self.map = ContactFrequency(traj, cutoff=0.075, - n_neighbors_ignored=0) - self.topology = self.map.topology - self.atom_contacts = self.map.atom_contacts - self.residue_contacts = self.map.residue_contacts - - self.atom_matrix = np.array([ - # 0 1 2 3 4 5 6 7 8 9 - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.2], # 0 - [0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 0.2, 0.2], # 1 - [0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 2 - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 3 - [0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 1.0, 0.4, 0.2, 0.0], # 4 - [0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.4, 0.2, 0.0], # 5 - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], # 6 - [0.0, 0.0, 0.0, 0.0, 0.4, 0.4, 0.0, 0.0, 0.0, 0.0], # 7 - [0.2, 0.2, 0.0, 0.0, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0], # 8 - [0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 9 - ]) - self.residue_matrix = np.array([ - # 0 1 2 3 4 - [0.0, 0.0, 1.0, 0.0, 0.2], # 0 - [0.0, 0.0, 0.0, 0.0, 0.0], # 1 - [1.0, 0.0, 0.0, 1.0, 0.2], # 2 - [0.0, 0.0, 1.0, 0.0, 0.0], # 3 - [0.2, 0.0, 0.2, 0.0, 0.0] # 4 - ]) - - # HAS_MATPLOTLIB imported by contact_map wildcard - @pytest.mark.skipif(not HAS_MATPLOTLIB, reason="Missing matplotlib") - def test_plot(self): - # purely smoke tests - self.residue_contacts.plot() - self.atom_contacts.plot() - self.residue_contacts.plot(with_colorbar=False) - - def test_initialization(self): - assert self.atom_contacts._object_f == self.topology.atom - assert self.atom_contacts.n_x == self.topology.n_atoms - assert self.atom_contacts.n_y == self.topology.n_atoms - assert self.residue_contacts._object_f == self.topology.residue - assert self.residue_contacts.n_x == self.topology.n_residues - assert self.residue_contacts.n_y == self.topology.n_residues - - def test_sparse_matrix(self): - assert_array_equal(self.map.atom_contacts.sparse_matrix.todense(), - self.atom_matrix) - assert_array_equal(self.map.residue_contacts.sparse_matrix.todense(), - self.residue_matrix) - - def test_df(self): - atom_df = self.map.atom_contacts.df - residue_df = self.map.residue_contacts.df - assert isinstance(atom_df, pd.SparseDataFrame) - assert isinstance(residue_df, pd.SparseDataFrame) - - assert_array_equal(atom_df.to_dense().as_matrix(), - zero_to_nan(self.atom_matrix)) - assert_array_equal(residue_df.to_dense().as_matrix(), - zero_to_nan(self.residue_matrix)) - - @pytest.mark.parametrize("obj_type", ['atom', 
'res']) - def test_most_common(self, obj_type): - if obj_type == 'atom': - source_expected = traj_atom_contact_count - contacts = self.map.atom_contacts - obj_func = self.topology.atom - elif obj_type == 'res': - source_expected = traj_residue_contact_count - contacts = self.map.residue_contacts - obj_func = self.topology.residue - else: - raise RuntimeError("This shouldn't happen") - - expected = [ - (frozenset([obj_func(idx) for idx in ll[0]]), float(ll[1]) / 5.0) - for ll in source_expected.items() - ] - - most_common = contacts.most_common() - cleaned = [(frozenset(ll[0]), ll[1]) for ll in most_common] - - check_most_common_order(most_common) - assert set(cleaned) == set(expected) - - @pytest.mark.parametrize("obj_type", ['atom', 'res']) - def test_most_common_with_object(self, obj_type): - top = self.topology - if obj_type == 'atom': - contacts = self.map.atom_contacts - obj = top.atom(4) - expected = [(frozenset([obj, top.atom(6)]), 1.0), - (frozenset([obj, top.atom(1)]), 0.8), - (frozenset([obj, top.atom(7)]), 0.4), - (frozenset([obj, top.atom(8)]), 0.2)] - elif obj_type == 'res': - contacts = self.map.residue_contacts - obj = self.topology.residue(2) - expected = [(frozenset([obj, top.residue(0)]), 1.0), - (frozenset([obj, top.residue(3)]), 1.0), - (frozenset([obj, top.residue(4)]), 0.2)] - else: - raise RuntimeError("This shouldn't happen") - - most_common = contacts.most_common(obj) - cleaned = [(frozenset(ll[0]), ll[1]) for ll in most_common] - - check_most_common_order(most_common) - assert set(cleaned) == set(expected) - - @pytest.mark.parametrize("obj_type", ['atom', 'res']) - def test_most_common_idx(self, obj_type): - if obj_type == 'atom': - source_expected = traj_atom_contact_count - contacts = self.map.atom_contacts - elif obj_type == 'res': - source_expected = traj_residue_contact_count - contacts = self.map.residue_contacts - else: - raise RuntimeError("This shouldn't happen") - - expected_count = [(ll[0], float(ll[1]) / 5.0) - for ll in 
source_expected.items()] - assert set(contacts.most_common_idx()) == set(expected_count) - - class TestContactDifference(object): def test_diff_traj_frame(self): ttraj = ContactFrequency(traj[0:4], cutoff=0.075, @@ -706,4 +578,3 @@ def test_plot(self): frame = ContactMap(traj[4], cutoff=0.075, n_neighbors_ignored=0) diff = ttraj - frame diff.residue_contacts.plot() - diff --git a/docs/api.rst b/docs/api.rst index d0c5df2..61376fb 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -37,7 +37,7 @@ Parallelization of ``ContactFrequency`` :toctree: api/generated/ frequency_task - dask_runner + DaskContactFrequency ----- @@ -81,7 +81,7 @@ Most common ~~~~~~~~~~~ Several methods begin with ``most_common``. The behavior for this is -inspired by the behavior of :method:`collections.Counter.most_common`, which +inspired by the behavior of :meth:`collections.Counter.most_common`, which returns elements and there counts ordered from most to least. Note that, unlike the original, we usually do not implement a way to only return the first ``n`` results (although this may be added later).