Skip to content

Commit

Permalink
Merge pull request #298 from choderalab/slice-sampler-state
Browse files Browse the repository at this point in the history
Support slice SamplerState with list of indices
  • Loading branch information
andrrizzi authored Oct 4, 2017
2 parents 1e998c9 + fe847d0 commit 7717e42
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
16 changes: 10 additions & 6 deletions openmmtools/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,12 +1738,9 @@ def has_nan(self):

def __getitem__(self, item):
sampler_state = self.__class__([])
if isinstance(item, slice):
# Copy original values to avoid side effects.
sampler_state._positions = copy.deepcopy(self._positions[item])
if self._velocities is not None:
sampler_state._velocities = copy.deepcopy(self._velocities[item].copy())
else: # Single index.

# Handle single index.
if np.issubdtype(type(item), np.integer):
# Here we don't need to copy since we instantiate a new array.
pos_value = self._positions[item].value_in_unit(self._positions.unit)
sampler_state._positions = unit.Quantity(np.array([pos_value]),
Expand All @@ -1752,6 +1749,13 @@ def __getitem__(self, item):
vel_value = self._velocities[item].value_in_unit(self._velocities.unit)
sampler_state._velocities = unit.Quantity(np.array([vel_value]),
self._velocities.unit)
else: # Assume slice or sequence.
# Copy original values to avoid side effects.
sampler_state._positions = copy.deepcopy(self._positions[item])
if self._velocities is not None:
sampler_state._velocities = copy.deepcopy(self._velocities[item].copy())

# Copy box vectors.
sampler_state.box_vectors = copy.deepcopy(self.box_vectors)

# Energies for only a subset of atoms is undefined.
Expand Down
12 changes: 7 additions & 5 deletions openmmtools/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,11 +857,13 @@ def test_operator_getitem(self):
sliced_sampler_state.positions[0][0] += 1 * unit.angstrom
assert sliced_sampler_state.positions[0][0] == sampler_state.positions[0][0] + 1 * unit.angstrom

sliced_sampler_state = sampler_state[2:10]
assert sliced_sampler_state.n_particles == 8
assert len(sliced_sampler_state.velocities) == 8
assert np.allclose(sliced_sampler_state.positions,
self.alanine_explicit_positions[2:10])
# SamplerState.__getitem__ should work for both slices and lists.
for sliced_sampler_state in [sampler_state[2:10],
sampler_state[list(range(2, 10))]]:
assert sliced_sampler_state.n_particles == 8
assert len(sliced_sampler_state.velocities) == 8
assert np.allclose(sliced_sampler_state.positions,
self.alanine_explicit_positions[2:10])

sliced_sampler_state = sampler_state[2:10:2]
assert sliced_sampler_state.n_particles == 4
Expand Down

0 comments on commit 7717e42

Please sign in to comment.