Skip to content

Commit

Permalink
Merge pull request #30 from dwhswenson/dask
Browse files Browse the repository at this point in the history
Parallelization of ContactFrequency with Dask.distributed
  • Loading branch information
dwhswenson authored Jan 20, 2018
2 parents 7e03cb4 + 4250e6b commit 991451c
Show file tree
Hide file tree
Showing 10 changed files with 588 additions and 0 deletions.
2 changes: 2 additions & 0 deletions contact_map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@

from .min_dist import NearestAtoms, MinimumDistanceCounter

from .dask_runner import DaskContactFrequency

# import concurrence
17 changes: 17 additions & 0 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,23 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45,
contacts = self._build_contact_map(trajectory)
(self._atom_contacts, self._residue_contacts) = contacts

def __hash__(self):
    """Hash consistent with :meth:`__eq__`.

    Combines the parent class hash, the atom and residue contact counts,
    and the number of frames.

    Returns
    -------
    int
        hash value for this object
    """
    # The contact mappings are compared with ``==`` in __eq__. Assuming they
    # are plain dict-like counters (TODO confirm), that comparison ignores
    # insertion order, so the hash must ignore it too: hash the items as an
    # unordered frozenset instead of an order-dependent tuple.
    return hash((super(ContactFrequency, self).__hash__(),
                 frozenset(self._atom_contacts.items()),
                 frozenset(self._residue_contacts.items()),
                 self.n_frames))

def __eq__(self, other):
    """Compare with another contact frequency.

    The parent class comparison is evaluated first; the atom contacts,
    residue contacts and frame count are only compared if it succeeds
    (the ``and`` chain short-circuits).

    Parameters
    ----------
    other : object
        object to compare against

    Returns
    -------
    result of the chained comparison; truthy only if every part matches
    """
    return (super(ContactFrequency, self).__eq__(other)
            and self._atom_contacts == other._atom_contacts
            and self._residue_contacts == other._residue_contacts
            and self.n_frames == other.n_frames)

def to_dict(self):
    """Dictionary form of this object.

    Starts from the parent class dictionary and adds the frame count
    under the ``'n_frames'`` key.

    Returns
    -------
    dict
        parent dictionary representation plus ``'n_frames'``
    """
    as_dict = super(ContactFrequency, self).to_dict()
    as_dict['n_frames'] = self.n_frames
    return as_dict

def _build_contact_map(self, trajectory):
# We actually build the contact map on a per-residue basis, although
Expand Down
73 changes: 73 additions & 0 deletions contact_map/dask_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Implementation of ContactFrequency parallelization using dask.distributed
"""

from . import frequency_task
from .contact_map import ContactFrequency, ContactObject
import mdtraj as md


def dask_run(trajectory, client, run_info):
    """
    Run the dask version of ContactFrequency.

    Note that the API of this function will definitely change before the
    release.

    Parameters
    ----------
    trajectory : mdtraj.Trajectory
        the full trajectory; only its length is used here, to decide how
        the frames are split into tasks
    client : dask.distributed.Client
        client used to submit the tasks; ``len(client.ncores())`` is taken
        as the number of workers
    run_info : dict
        keys are 'trajectory_file' (trajectory filename), 'load_kwargs'
        (additional kwargs passed to md.load), and 'parameters' (dict of
        kwargs for the ContactFrequency object)

    Returns
    -------
    :class:`.ContactFrequency` :
        total contact frequency for the trajectory
    """
    n_frames = len(trajectory)
    n_workers = len(client.ncores())
    slices = frequency_task.default_slices(n_total=n_frames,
                                           n_workers=n_workers)

    # stage 1: each task loads its own segment of the trajectory file
    subtrajs = client.map(
        frequency_task.load_trajectory_task,
        slices,
        file_name=run_info['trajectory_file'],
        **run_info['load_kwargs']
    )
    # stage 2: contact frequency of each segment
    maps = client.map(
        frequency_task.map_task,
        subtrajs,
        parameters=run_info['parameters']
    )
    # stage 3: combine the partial results in a single task
    freq = client.submit(frequency_task.reduce_all_results, maps)

    # block until the combined result is available
    total = freq.result()
    return total

class DaskContactFrequency(ContactFrequency):
    """ContactFrequency whose contacts are calculated through dask.

    The trajectory file is loaded once locally (its topology and length
    are needed to set up the object), and the contact counting itself is
    delegated to :func:`dask_run`.

    Parameters
    ----------
    client : dask.distributed.Client
        client passed on to :func:`dask_run`
    filename : str
        trajectory file name, passed to ``md.load``
    query, haystack, cutoff, n_neighbors_ignored :
        forwarded to ``ContactObject.__init__``
    kwargs :
        additional keyword arguments for ``md.load``
    """
    def __init__(self, client, filename, query=None, haystack=None,
                 cutoff=0.45, n_neighbors_ignored=2, **kwargs):
        self.client = client
        self.filename = filename
        trajectory = md.load(filename, **kwargs)

        n_frames = len(trajectory)
        self.frames = range(n_frames)
        self.kwargs = kwargs

        # run_info (used below) reads attributes set by this initializer,
        # so it has to happen before the dask calculation starts
        ContactObject.__init__(self, trajectory.topology, query, haystack,
                               cutoff, n_neighbors_ignored)

        result = dask_run(trajectory, client, self.run_info)
        # copy the combined result into this object
        self._n_frames = result.n_frames
        self._atom_contacts = result._atom_contacts
        self._residue_contacts = result._residue_contacts

    @property
    def parameters(self):
        """dict : keyword arguments describing this contact calculation"""
        params = dict(query=self.query,
                      haystack=self.haystack,
                      cutoff=self.cutoff,
                      n_neighbors_ignored=self.n_neighbors_ignored)
        return params

    @property
    def run_info(self):
        """dict : the ``run_info`` mapping expected by :func:`dask_run`"""
        info = dict(parameters=self.parameters,
                    trajectory_file=self.filename,
                    load_kwargs=self.kwargs)
        return info


132 changes: 132 additions & 0 deletions contact_map/frequency_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Task-based implementation of :class:`.ContactFrequency`.
The overall algorithm is:
1. Identify how we're going to slice up the trajectory into task-based
chunks (:meth:`block_slices`, :meth:`default_slices`)
2. On each node
a. Load the trajectory segment (:meth:`load_trajectory_task`)
b. Run the analysis on the segment (:meth:`map_task`)
3. Once all the results have been collected, combine them
(:meth:`reduce_all_results`)
Notes
-----
Includes versions where messages are Python objects and versions (labelled
with _json) where messages have been JSON-serialized. However, we don't yet
have a solution for JSON serialization of MDTraj objects, so if JSON
serialization is the communication method, the loading of the trajectory and
the calculation of the contacts must be combined into a single task.
"""

import mdtraj as md
from contact_map import ContactFrequency

def block_slices(n_total, n_per_block):
    """Determine slices for splitting the input array.

    Every slice covers ``n_per_block`` items, except possibly the last
    one, which covers whatever remains.

    Parameters
    ----------
    n_total : int
        total length of array
    n_per_block : int
        maximum number of items per block; must be at least 1

    Returns
    -------
    list of slice
        slices to be applied to the array; empty if ``n_total`` is not
        positive

    Raises
    ------
    ValueError
        if ``n_per_block`` is less than 1
    """
    if n_per_block < 1:
        # a zero block size used to fail with ZeroDivisionError, and a
        # negative one silently produced meaningless slices
        raise ValueError("n_per_block must be at least 1, got %r"
                         % (n_per_block,))
    # one slice per block start; min() trims the final (partial) block
    return [slice(start, min(start + n_per_block, n_total))
            for start in range(0, n_total, n_per_block)]

def default_slices(n_total, n_workers):
    """Calculate default slices from number of workers.

    Default behavior is (approximately) one task per worker: the block
    size is the integer quotient of items by workers (but never less than
    one), so a remainder produces one extra, shorter block.

    Parameters
    ----------
    n_total : int
        total number of items in array
    n_workers : int
        number of workers

    Returns
    -------
    list of slice
        slices to be applied to the array
    """
    even_share = n_total // n_workers
    # with more workers than items the quotient is 0; use single-item blocks
    block_size = max(1, even_share)
    return block_slices(n_total, block_size)


def load_trajectory_task(subslice, file_name, **kwargs):
    """
    Task for loading file. Arguments are ordered so that the per-task
    variable comes first.

    The whole file is loaded and then sliced.

    Parameters
    ----------
    subslice : slice
        the slice of the trajectory to use
    file_name : str
        trajectory file name
    kwargs :
        other parameters to mdtraj.load

    Returns
    -------
    md.Trajectory :
        subtrajectory for this slice
    """
    full_trajectory = md.load(file_name, **kwargs)
    return full_trajectory[subslice]

def map_task(subtrajectory, parameters):
    """Task to be mapped to all subtrajectories: build a ContactFrequency.

    Parameters
    ----------
    subtrajectory : mdtraj.Trajectory
        single trajectory segment to calculate ContactFrequency for
    parameters : dict
        kwargs-style dict for the :class:`.ContactFrequency` object

    Returns
    -------
    :class:`.ContactFrequency` :
        contact frequency for the subtrajectory
    """
    partial_frequency = ContactFrequency(subtrajectory, **parameters)
    return partial_frequency

def reduce_all_results(contacts):
    """Combine multiple :class:`.ContactFrequency` objects into one

    The first element is used as the accumulator: every other element is
    added to it with ``add_contact_frequency``, and that same first object
    is what gets returned (no copy is made).

    Parameters
    ----------
    contacts : iterable of :class:`.ContactFrequency`
        the individual (partial) contact frequencies; any iterable is
        accepted, including generators

    Returns
    -------
    :class:`.ContactFrequency` :
        total of all input contact frequencies (summing them)

    Raises
    ------
    IndexError
        if ``contacts`` is empty
    """
    # materialize first, so one-shot iterables (generators, iterators)
    # work as well as lists and tuples
    partials = list(contacts)
    if not partials:
        # same exception type the bare ``contacts[0]`` lookup used to raise
        raise IndexError("reduce_all_results requires at least one "
                         "contact frequency")
    accumulator = partials[0]
    for contact in partials[1:]:
        accumulator.add_contact_frequency(contact)
    return accumulator


def map_task_json(subtrajectory, parameters):
    """JSON-serialized version of :meth:`map_task`

    Runs :meth:`map_task` and returns the ``to_json()`` form of its result.
    """
    partial_frequency = map_task(subtrajectory, parameters)
    return partial_frequency.to_json()

def reduce_all_results_json(results_of_map):
    """JSON-serialized version of :meth:`reduce_all_results`

    Each input is rebuilt with ``ContactFrequency.from_json`` before the
    rebuilt objects are combined; the return value is the combined object
    itself, not a JSON string.
    """
    rebuilt = []
    for serialized in results_of_map:
        rebuilt.append(ContactFrequency.from_json(serialized))
    return reduce_all_results(rebuilt)
11 changes: 11 additions & 0 deletions contact_map/tests/test_contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,17 @@ def test_frames_parameter(self):
}
assert contacts.residue_contacts.counter == expected_residue_count

def test_hash(self):
    # a map rebuilt with the same settings must hash like self.map; a map
    # built from only the first two frames must not
    settings = {'cutoff': 0.075, 'n_neighbors_ignored': 0}
    same_map = ContactFrequency(trajectory=traj, **settings)
    shorter_map = ContactFrequency(trajectory=traj[:2], **settings)

    assert hash(self.map) == hash(same_map)
    assert hash(self.map) != hash(shorter_map)

def test_saving(self):
m = self.map
m.save_to_file(test_file)
Expand Down
17 changes: 17 additions & 0 deletions contact_map/tests/test_dask_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .utils import *

from contact_map.dask_runner import *

class TestDaskContactFrequency(object):
    """Integration test for :class:`DaskContactFrequency`."""
    def test_dask_integration(self):
        # this is an integration test to check that dask works; it is
        # skipped when dask or dask.distributed cannot be imported
        pytest.importorskip('dask')
        distributed = pytest.importorskip('dask.distributed')

        client = distributed.Client()
        try:
            filename = find_testfile("trajectory.pdb")
            dask_freq = DaskContactFrequency(client, filename, cutoff=0.075,
                                             n_neighbors_ignored=0)
        finally:
            # close the client even if the calculation raises, so a failing
            # test does not leave it open
            client.close()
        assert dask_freq.n_frames == 5
93 changes: 93 additions & 0 deletions contact_map/tests/test_frequency_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
import collections

from .utils import *
from .test_contact_map import traj

from contact_map.frequency_task import *
from contact_map import ContactFrequency

class TestSlicing(object):
    """Tests for block_slices and default_slices."""
    @pytest.mark.parametrize("inputs, results", [
        ((100, 25),
         [slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 100)]),
        ((85, 25),
         [slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 85)])
    ])
    def test_block_slices(self, inputs, results):
        # inputs is the (n_total, n_per_block) pair
        calculated = block_slices(inputs[0], inputs[1])
        assert calculated == results

    @pytest.mark.parametrize("inputs, results", [
        ((100, 4),
         [slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 100)]),
        ((77, 3),
         [slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 77)]),
        ((2, 20),
         [slice(0, 1), slice(1, 2)])
    ])
    def test_default_slice_even_split(self, inputs, results):
        # inputs is the (n_total, n_workers) pair
        calculated = default_slices(inputs[0], inputs[1])
        assert calculated == results

class TestTasks(object):
    """Tests for the load / map / reduce tasks, object and JSON based."""
    def setup_method(self):
        # pytest's native per-test hook; the nose-style name ``setup`` is
        # not called by current pytest versions, which would leave the
        # attributes below undefined
        self.contact_freq_0_4 = ContactFrequency(traj, cutoff=0.075,
                                                 n_neighbors_ignored=0,
                                                 frames=range(4))
        self.contact_freq_4 = ContactFrequency(traj, cutoff=0.075,
                                               n_neighbors_ignored=0,
                                               frames=[4])
        self.total_contact_freq = ContactFrequency(traj, cutoff=0.075,
                                                   n_neighbors_ignored=0)
        self.parameters = {'cutoff': 0.075, 'n_neighbors_ignored': 0}

    @staticmethod
    def _load_subtrajectories(n_workers):
        # Load the test trajectory as one segment per default slice.
        file_name = find_testfile("trajectory.pdb")
        slices = default_slices(len(traj), n_workers=n_workers)
        return [load_trajectory_task(subslice=sl, file_name=file_name)
                for sl in slices]

    def test_load_trajectory_task(self):
        subslice = slice(0, 4)
        file_name = find_testfile("trajectory.pdb")
        trajectory = load_trajectory_task(subslice, file_name)
        assert trajectory.xyz.shape == (4, 10, 3)

    def test_map_task(self):
        trajectory = traj[:4]
        mapped = map_task(trajectory, parameters=self.parameters)
        assert mapped == self.contact_freq_0_4

    def test_reduce_task(self):
        reduced = reduce_all_results([self.contact_freq_0_4,
                                      self.contact_freq_4])
        assert reduced == self.total_contact_freq

    def test_map_task_json(self):
        # check the json objects by converting them back to full objects
        trajectory = traj[:4]
        mapped = map_task_json(trajectory, parameters=self.parameters)
        assert ContactFrequency.from_json(mapped) == self.contact_freq_0_4

    def test_reduce_all_results_json(self):
        reduced = reduce_all_results_json([self.contact_freq_0_4.to_json(),
                                           self.contact_freq_4.to_json()])
        assert reduced == self.total_contact_freq

    def test_integration_object_based(self):
        # full load -> map -> reduce pipeline, passing Python objects
        trajs = self._load_subtrajectories(n_workers=3)
        mapped = [map_task(subtraj, self.parameters) for subtraj in trajs]
        result = reduce_all_results(mapped)
        assert result == self.total_contact_freq

    def test_integration_json_based(self):
        # same pipeline, with JSON strings between the map and reduce steps
        trajs = self._load_subtrajectories(n_workers=3)
        mapped = [map_task_json(subtraj, self.parameters)
                  for subtraj in trajs]
        result = reduce_all_results_json(mapped)
        assert result == self.total_contact_freq

9 changes: 9 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ Minimum Distance (and related)

MinimumDistanceCounter
NearestAtoms

Parallelization of ``ContactFrequency``
---------------------------------------

.. autosummary::
:toctree: api/generated/

frequency_task
dask_runner
Loading

0 comments on commit 991451c

Please sign in to comment.