-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from dwhswenson/dask
Parallelization of ContactFrequency with Dask.distributed
- Loading branch information
Showing
10 changed files
with
588 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
""" | ||
Implementation of ContactFrequency parallelization using dask.distributed | ||
""" | ||
|
||
from . import frequency_task | ||
from .contact_map import ContactFrequency, ContactObject | ||
import mdtraj as md | ||
|
||
|
||
def dask_run(trajectory, client, run_info): | ||
""" | ||
Runs dask version of ContactFrequency. Note that this API on this will | ||
definitely change before the release. | ||
Parameters | ||
---------- | ||
trajectory : mdtraj.trajectory | ||
client : dask.distributed.Client | ||
path to dask scheduler file | ||
run_info : dict | ||
keys are 'trajectory_file' (trajectory filename), 'load_kwargs' | ||
(additional kwargs passed to md.load), and 'parameters' (dict of | ||
kwargs for the ContactFrequency object) | ||
Returns | ||
------- | ||
:class:`.ContactFrequency` : | ||
total contact frequency for the trajectory | ||
""" | ||
slices = frequency_task.default_slices(n_total=len(trajectory), | ||
n_workers=len(client.ncores())) | ||
|
||
subtrajs = client.map(frequency_task.load_trajectory_task, slices, | ||
file_name=run_info['trajectory_file'], | ||
**run_info['load_kwargs']) | ||
maps = client.map(frequency_task.map_task, subtrajs, | ||
parameters=run_info['parameters']) | ||
freq = client.submit(frequency_task.reduce_all_results, maps) | ||
|
||
return freq.result() | ||
|
||
class DaskContactFrequency(ContactFrequency): | ||
def __init__(self, client, filename, query=None, haystack=None, | ||
cutoff=0.45, n_neighbors_ignored=2, **kwargs): | ||
self.client = client | ||
self.filename = filename | ||
trajectory = md.load(filename, **kwargs) | ||
|
||
self.frames = range(len(trajectory)) | ||
self.kwargs = kwargs | ||
|
||
ContactObject.__init__(self, trajectory.topology, query, haystack, | ||
cutoff, n_neighbors_ignored) | ||
|
||
freq = dask_run(trajectory, client, self.run_info) | ||
self._n_frames = freq.n_frames | ||
self._atom_contacts = freq._atom_contacts | ||
self._residue_contacts = freq._residue_contacts | ||
|
||
@property | ||
def parameters(self): | ||
return {'query': self.query, | ||
'haystack': self.haystack, | ||
'cutoff': self.cutoff, | ||
'n_neighbors_ignored': self.n_neighbors_ignored} | ||
|
||
@property | ||
def run_info(self): | ||
return {'parameters': self.parameters, | ||
'trajectory_file': self.filename, | ||
'load_kwargs': self.kwargs} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
""" | ||
Task-based implementation of :class:`.ContactFrequency`. | ||
The overall algorithm is: | ||
1. Identify how we're going to slice up the trajectory into task-based | ||
chunks (:meth:`block_slices`, :meth:`default_slices`) | ||
2. On each node | ||
a. Load the trajectory segment (:meth:`load_trajectory_task`) | ||
b. Run the analysis on the segment (:meth:`map_task`) | ||
3. Once all the results have been collected, combine them | ||
(:meth:`reduce_all_results`) | ||
Notes | ||
----- | ||
Includes versions where messages are Python objects and versions (labelled | ||
with _json) where messages have been JSON-serialized. However, we don't yet | ||
have a solution for JSON serialization of MDTraj objects, so if JSON | ||
serialization is the communication method, the loading of the trajectory and | ||
the calculation of the contacts must be combined into a single task. | ||
""" | ||
|
||
import mdtraj as md | ||
from contact_map import ContactFrequency | ||
|
||
def block_slices(n_total, n_per_block): | ||
"""Determine slices for splitting the input array. | ||
Parameters | ||
---------- | ||
n_total : int | ||
total length of array | ||
n_per_block : int | ||
maximum number of items per block | ||
Returns | ||
------- | ||
list of slice | ||
slices to be applied to the array | ||
""" | ||
n_full_blocks = n_total // n_per_block | ||
slices = [slice(i*n_per_block, (i+1)*n_per_block) | ||
for i in range(n_full_blocks)] | ||
if n_total % n_per_block: | ||
slices.append(slice(n_full_blocks*n_per_block, n_total)) | ||
return slices | ||
|
||
def default_slices(n_total, n_workers): | ||
"""Calculate default slices from number of workers. | ||
Default behavior is (approximately) one task per worker. | ||
Parameters | ||
---------- | ||
n_total : int | ||
total number of items in array | ||
n_workers : int | ||
number of workers | ||
Returns | ||
------- | ||
list of slice | ||
slices to be applied to the array | ||
""" | ||
n_frames_per_task = max(1, n_total // n_workers) | ||
return block_slices(n_total, n_frames_per_task) | ||
|
||
|
||
def load_trajectory_task(subslice, file_name, **kwargs): | ||
""" | ||
Task for loading file. Reordered for to take per-task variable first. | ||
Parameters | ||
---------- | ||
subslice : slice | ||
the slice of the trajectory to use | ||
file_name : str | ||
trajectory file name | ||
kwargs : | ||
other parameters to mdtraj.load | ||
Returns | ||
------- | ||
md.Trajectory : | ||
subtrajectory for this slice | ||
""" | ||
return md.load(file_name, **kwargs)[subslice] | ||
|
||
def map_task(subtrajectory, parameters): | ||
"""Task to be mapped to all subtrajectories. Run ContactFrequency | ||
Parameters | ||
---------- | ||
subtrajectory : mdtraj.Trajectory | ||
single trajectory segment to calculate ContactFrequency for | ||
parameters : dict | ||
kwargs-style dict for the :class:`.ContactFrequency` object | ||
Returns | ||
------- | ||
:class:`.ContactFrequency` : | ||
contact frequency for the subtrajectory | ||
""" | ||
return ContactFrequency(subtrajectory, **parameters) | ||
|
||
def reduce_all_results(contacts): | ||
"""Combine multiple :class:`.ContactFrequency` objects into one | ||
Parameters | ||
---------- | ||
contacts : iterable of :class:`.ContactFrequency` | ||
the individual (partial) contact frequencies | ||
Returns | ||
------- | ||
:class:`.ContactFrequency` : | ||
total of all input contact frequencies (summing them) | ||
""" | ||
accumulator = contacts[0] | ||
for contact in contacts[1:]: | ||
accumulator.add_contact_frequency(contact) | ||
return accumulator | ||
|
||
|
||
def map_task_json(subtrajectory, parameters): | ||
"""JSON-serialized version of :meth:`map_task`""" | ||
return map_task(subtrajectory, parameters).to_json() | ||
|
||
def reduce_all_results_json(results_of_map): | ||
"""JSON-serialized version of :meth:`reduce_all_results`""" | ||
contacts = [ContactFrequency.from_json(res) for res in results_of_map] | ||
return reduce_all_results(contacts) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from .utils import * | ||
|
||
from contact_map.dask_runner import * | ||
|
||
class TestDaskContactFrequency(object): | ||
def test_dask_integration(self): | ||
# this is an integration test to check that dask works | ||
dask = pytest.importorskip('dask') | ||
distributed = pytest.importorskip('dask.distributed') | ||
|
||
client = distributed.Client() | ||
filename = find_testfile("trajectory.pdb") | ||
|
||
dask_freq = DaskContactFrequency(client, filename, cutoff=0.075, | ||
n_neighbors_ignored=0) | ||
client.close() | ||
assert dask_freq.n_frames == 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
import collections | ||
|
||
from .utils import * | ||
from .test_contact_map import traj | ||
|
||
from contact_map.frequency_task import * | ||
from contact_map import ContactFrequency | ||
|
||
class TestSlicing(object): | ||
# tests for block_slices and default_slices | ||
@pytest.mark.parametrize("inputs, results", [ | ||
((100, 25), | ||
[slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 100)]), | ||
((85, 25), | ||
[slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 85)]) | ||
]) | ||
def test_block_slices(self, inputs, results): | ||
n_total, n_per_block = inputs | ||
assert block_slices(n_total, n_per_block) == results | ||
|
||
@pytest.mark.parametrize("inputs, results", [ | ||
((100, 4), | ||
[slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 100)]), | ||
((77, 3), | ||
[slice(0, 25), slice(25, 50), slice(50, 75), slice(75, 77)]), | ||
((2, 20), | ||
[slice(0, 1), slice(1, 2)]) | ||
]) | ||
def test_default_slice_even_split(self, inputs, results): | ||
n_total, n_workers = inputs | ||
assert default_slices(n_total, n_workers) == results | ||
|
||
class TestTasks(object): | ||
def setup(self): | ||
self.contact_freq_0_4 = ContactFrequency(traj, cutoff=0.075, | ||
n_neighbors_ignored=0, | ||
frames=range(4)) | ||
self.contact_freq_4 = ContactFrequency(traj, cutoff=0.075, | ||
n_neighbors_ignored=0, | ||
frames=[4]) | ||
self.total_contact_freq = ContactFrequency(traj, cutoff=0.075, | ||
n_neighbors_ignored=0) | ||
self.parameters = {'cutoff': 0.075, 'n_neighbors_ignored': 0} | ||
|
||
def test_load_trajectory_task(self): | ||
subslice = slice(0, 4) | ||
file_name = find_testfile("trajectory.pdb") | ||
trajectory = load_trajectory_task(subslice, file_name) | ||
assert trajectory.xyz.shape == (4, 10, 3) | ||
|
||
def test_map_task(self): | ||
trajectory = traj[:4] | ||
mapped = map_task(trajectory, parameters=self.parameters) | ||
assert mapped == self.contact_freq_0_4 | ||
|
||
def test_reduce_task(self): | ||
reduced = reduce_all_results([self.contact_freq_0_4, | ||
self.contact_freq_4]) | ||
assert reduced == self.total_contact_freq | ||
|
||
def test_map_task_json(self): | ||
# check the json objects by converting them back to full objects | ||
trajectory = traj[:4] | ||
mapped = map_task_json(trajectory, parameters=self.parameters) | ||
assert ContactFrequency.from_json(mapped) == self.contact_freq_0_4 | ||
|
||
def test_reduce_all_results_json(self): | ||
reduced = reduce_all_results_json([self.contact_freq_0_4.to_json(), | ||
self.contact_freq_4.to_json()]) | ||
assert reduced == self.total_contact_freq | ||
|
||
def test_integration_object_based(self): | ||
file_name = find_testfile("trajectory.pdb") | ||
slices = default_slices(len(traj), n_workers=3) | ||
trajs = [load_trajectory_task(subslice=sl, | ||
file_name=file_name) | ||
for sl in slices] | ||
mapped = [map_task(subtraj, self.parameters) for subtraj in trajs] | ||
result = reduce_all_results(mapped) | ||
assert result == self.total_contact_freq | ||
|
||
def test_integration_json_based(self): | ||
file_name = find_testfile("trajectory.pdb") | ||
slices = default_slices(len(traj), n_workers=3) | ||
trajs = [load_trajectory_task(subslice=sl, | ||
file_name=file_name) | ||
for sl in slices] | ||
mapped = [map_task_json(subtraj, self.parameters) | ||
for subtraj in trajs] | ||
result = reduce_all_results_json(mapped) | ||
assert result == self.total_contact_freq | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.