Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify cache #67

Merged
merged 10 commits into from
Oct 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 50 additions & 166 deletions zarrtraj/ZARR.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ class ZARRH5MDReader(base.ReaderBase):
@due.dcite(
Doi("10.1002/jcc.21787"), description="MDAnalysis 2011", path=__name__
)
@due.dcite(Doi("10.5281/zenodo.3773449"), description="Zarr", path=__name__)
@due.dcite(
Doi("10.5281/zenodo.3773449"), description="Zarr", path=__name__
)
@store_init_arguments
def __init__(
self,
Expand Down Expand Up @@ -313,8 +315,7 @@ def __init__(
# Set to none so close() can be called
self._file = None
self._cache = None
# Read first timestep
self._frame_seq = collections.deque([0])

if not HAS_ZARR:
raise RuntimeError("Please install zarr")
super(ZARRH5MDReader, self).__init__(filename, **kwargs)
Expand Down Expand Up @@ -397,7 +398,7 @@ def __init__(
self._global_steparray,
self._stepmaps,
)
self._cache.update_frame_seq(self._frame_seq)

self._read_next_timestep()

def _set_translated_units(self):
Expand Down Expand Up @@ -628,22 +629,14 @@ def _read_next_timestep(self):

def _read_frame(self, frame):
"""reads data from h5md-formatted file and copies to current timestep"""
# frame seq update case 1: read called from iterator-like context
if not self._frame_seq:
self._frame_seq = None
self._cache.update_frame_seq(self._frame_seq)
raise StopIteration
if frame < 0 or frame >= self.n_frames:
raise IOError("Frame index out of range")

self._frame = self._cache.load_frame()
self._frame = self._cache.load_frame(frame)

if self.convert_units:
self._convert_units()

# frame seq update case 2: read called from __getitem__-like context
if len(self._frame_seq) == 0:
self._frame_seq = None
self._cache.update_frame_seq(self._frame_seq)

return self.ts

def _convert_units(self):
Expand All @@ -667,7 +660,6 @@ def _convert_units(self):

def close(self):
"""close reader"""
self._frame_seq = None
if self._cache is not None:
self._cache.cleanup()
if self._file is not None:
Expand All @@ -687,150 +679,6 @@ def Writer(self, filename, n_atoms=None, **kwargs):
kwargs.setdefault("forces", ("force" in self._elements))
return ZARRMDWriter(filename, n_atoms, **kwargs)

def __getitem__(self, frame):
"""Return the Timestep corresponding to *frame*.

If `frame` is an integer then the corresponding frame is
returned. Negative numbers are counted from the end.

If frame is a :class:`slice` then an iterator is returned that
allows iteration over that part of the trajectory.

Note
----
*frame* is a 0-based frame index.

Note
----
ZARRH5MDReader overrides this method to get
access to the sequence of frames
the user wants.
"""
if isinstance(frame, numbers.Integral):
frame = self._apply_limits(frame)
if self._frame_seq is None:
self._frame_seq = collections.deque([frame])
self._cache.update_frame_seq(self._frame_seq)
return self._read_frame_with_aux(frame)
elif isinstance(frame, (list, np.ndarray)):
if len(frame) != 0 and isinstance(frame[0], (bool, np.bool_)):
# Avoid having list of bools
frame = np.asarray(frame, dtype=bool)
# Convert bool array to int array
frame = np.arange(len(self))[frame]
if isinstance(frame, np.ndarray):
frame = frame.tolist()
if self._frame_seq is None:
self._frame_seq = collections.deque(frame)
self._cache.update_frame_seq(self._frame_seq)
return base.FrameIteratorIndices(self, frame)
elif isinstance(frame, slice):
start, stop, step = self.check_slice_indices(
frame.start, frame.stop, frame.step
)
if self._frame_seq is None:
self._frame_seq = collections.deque(range(start, stop, step))
self._cache.update_frame_seq(self._frame_seq)
if start == 0 and stop == len(self) and step == 1:
return base.FrameIteratorAll(self)
else:
return base.FrameIteratorSliced(self, frame)
else:
raise TypeError(
"Trajectories must be an indexed using an integer,"
" slice or list of indices"
)

def __iter__(self):
"""Iterate over all frames in the trajectory

Note
----
ZARRH5MDReader overrides this method to get
access to the sequence of frames
the user wants.
"""
self._reopen()
self._frame_seq = collections.deque(range(0, self.n_frames))
self._cache.update_frame_seq(self._frame_seq)
return self

def next(self):
if self._frame_seq is None and self._frame + 1 < self.n_frames:
self._frame_seq = collections.deque([self._frame + 1])
self._cache.update_frame_seq(self._frame_seq)
elif self._frame_seq is None:
self.rewind()
raise StopIteration from None
try:
ts = self._read_next_timestep()
except (EOFError, IOError):
self.rewind()
raise StopIteration from None
else:
for auxname, reader in self._auxs.items():
ts = self._auxs[auxname].update_ts(ts)

ts = self._apply_transformations(ts)

return ts

def iter_as_aux(self, auxname):
"""Iterate over the trajectory with an auxiliary reader

Note
----
ZARRH5MDReader overrides this method to get
access to the sequence of frames
the user wants.
"""
aux = self._check_for_aux(auxname)
self._reopen()
self._frame_seq = collections.deque(range(0, self.n_frames))
self._cache.update_frame_seq(self._frame_seq)
aux._restart()
while True:
try:
yield self.next_as_aux(auxname)
except StopIteration:
return

def copy(self):
"""Return independent copy of this Reader.

New Reader will have its own file handle and can seek/iterate
independently of the original.

Will also copy the current state of the Timestep held in the original
Reader.

Note
----
ZARRH5MDReader overrides this method to get
access to the sequence of frames
the user wants.

.. versionchanged:: 2.2.0
Arguments used to construct the reader are correctly captured and
passed to the creation of the new class. Previously only
``n_atoms`` was passed to class copies, leading to a class created
with default parameters which may differ from the original class.
"""

new = self.__class__(**self._kwargs)

if self.transformations:
new.add_transformations(*self.transformations)
# seek the new reader to the same frame we started with
new[self.ts.frame]
# then copy over the current Timestep in case it has
# been modified since initial load
new.ts = self.ts.copy()
new._cache._timestep = new.ts
for auxname, auxread in self._auxs.items():
new.add_auxiliary(auxname, auxread.copy())
return new

@property
def n_frames(self):
"""number of frames in trajectory"""
Expand Down Expand Up @@ -870,6 +718,41 @@ def parse_n_atoms(filename, group=None, so=None):
"You must include a topology file."
)

def copy(self):
"""Return independent copy of this Reader.

New Reader will have its own file handle and can seek/iterate
independently of the original.

Will also copy the current state of the Timestep held in the original
Reader.

Note
----
ZARRH5MDReader overrides this method so that the new reader's
cache also references the copied timestep

.. versionchanged:: 2.2.0
Arguments used to construct the reader are correctly captured and
passed to the creation of the new class. Previously only
``n_atoms`` was passed to class copies, leading to a class created
with default parameters which may differ from the original class.
"""

new = self.__class__(**self._kwargs)

if self.transformations:
new.add_transformations(*self.transformations)
# seek the new reader to the same frame we started with
new[self.ts.frame]
# then copy over the current Timestep in case it has
# been modified since initial load
new.ts = self.ts.copy()
new._cache._timestep = new.ts
for auxname, auxread in self._auxs.items():
new.add_auxiliary(auxname, auxread.copy())
return new


class H5MDElementBuffer:
def __init__(
Expand Down Expand Up @@ -996,9 +879,9 @@ def flush(self):
if num_v_frames == 0:
num_v_frames = self._val_frames_per_chunk

self._val[self._val_idx - num_v_frames : self._val_idx] = self._val_buf[
:num_v_frames
]
self._val[self._val_idx - num_v_frames : self._val_idx] = (
self._val_buf[:num_v_frames]
)
self._val.resize(self._val_idx, *self._val_chunks[1:])

num_t_frames = self._t_idx % self._t_frames_per_chunk
Expand Down Expand Up @@ -1248,7 +1131,9 @@ def __init__(

protocol = get_protocol(filename)
if protocol not in ZARRTRAJ_NETWORK_PROTOCOLS and protocol != "file":
raise ValueError(f"Unsupported protocol '{protocol}' for Zarrtraj.")
raise ValueError(
f"Unsupported protocol '{protocol}' for Zarrtraj."
)
if protocol in ZARRTRAJ_EXPERIMENTAL_PROTOCOLS:
warnings.warn(
f"Zarrtraj is using the experimental protocol '{protocol}' "
Expand Down Expand Up @@ -1649,9 +1534,8 @@ def update_desired_dsets(
self._global_steparray = global_steparray
self._stepmaps = stepmaps

def load_frame(self):
def load_frame(self, frame):
"""Reader responsible for raising StopIteration when no more frames"""
frame = self._frame_seq.popleft()
self._load_timestep_frame(frame)
return frame

Expand Down
Loading