Skip to content

Commit

Permalink
introduce some helper functions for typed get of pages
Browse files Browse the repository at this point in the history
  • Loading branch information
codingchipmunk committed Mar 10, 2023
1 parent 709e218 commit 10c2671
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions pydapsys/neo_convert/abstract_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy.typing as npt
import quantities as pq

from pydapsys.page import DataPage, WaveformPage
from pydapsys.page import DataPage, WaveformPage, TextPage, PageType
from pydapsys.toc.entry import Root, Stream, StreamType
from pydapsys.util.floats import float_comp

Expand All @@ -33,6 +33,30 @@ def to_neo(self) -> neo.Block:
"""
...

def _get_datapage_typechecked(self, pid: int, ptype: PageType) -> DataPage:
"""
gets the page with the given id and checks if it is of the requested type.
Throws an exception if the type doesn't match
"""
page = self.pages[pid]
if page.type != ptype:
raise Exception(f"page {pid} is not of type {ptype.Text}, but {page.type.Text}")
return page

def get_textpage(self, pid: int) -> TextPage:
"""
Gets the page with the given id and checks if it is a text page. If it isn't, an exception is thrown.
Primarily serves as a convenience function for type hinting.
"""
return self._get_datapage_typechecked(pid, PageType.Text)

def get_waveformpage(self, pid: int) -> WaveformPage:
"""
Gets the page with the given id and checks if it is a waveform page. If it isn't, an exception is thrown.
Primarily serves as a convenience function for type hinting.
"""
return self._get_datapage_typechecked(pid, PageType.Waveform)

def _pageids_to_event(self, page_ids: Union[Sequence[int], npt.NDArray[np.uint32]], name: str = "") -> neo.Event:
"""Converts data from a sequence (or numpy array) of page ids to a neo event.
The labels will be taken from the page text and the event times from the first timestamp (timestamp_a)
Expand All @@ -42,9 +66,9 @@ def _pageids_to_event(self, page_ids: Union[Sequence[int], npt.NDArray[np.uint32
"""
times = np.empty(len(page_ids), dtype=np.float64)
comments = np.empty(len(page_ids), dtype=str)
for i in range(len(page_ids)):
times[i] = self.pages[page_ids[i]].timestamp_a
comments[i] = self.pages[page_ids[i]].comment
for i, page in enumerate(self.get_textpage(pid) for pid in page_ids):
times[i] = page.timestamp_a
comments[i] = page.text
return neo.Event(times=times, labels=comments, units=pq.second, name=name, copy=False)

def textstream_to_event(self, stream: Stream, name: Optional[str] = None) -> neo.Event:
Expand Down Expand Up @@ -73,7 +97,7 @@ def _pageids_to_spiketrain(self, page_ids: Union[Sequence[int], npt.NDArray[np.u
:return: A spike train build from the comment pages
"""
return neo.SpikeTrain(
times=np.fromiter((comment.timestamp_a for comment in (self.pages[pid] for pid in page_ids)),
times=np.fromiter((comment.timestamp_a for comment in (self.get_textpage(pid) for pid in page_ids)),
dtype=np.float64, count=len(page_ids)), name=name, units=pq.second, t_stop=t_stop,
copy=False)

Expand All @@ -100,8 +124,8 @@ def _pageids_to_events_by_comment_text(self, page_ids: Union[Sequence[int], npt.
:return: An iterable of neo events
"""
comment_string_to_timestamps: Dict[str, List[float]] = dict()
for comment in (self.pages[pid] for pid in page_ids):
comment_string_to_timestamps.setdefault(comment.comment, list()).append(comment.timestamp_a)
for comment in (self.get_textpage(pid) for pid in page_ids):
comment_string_to_timestamps.setdefault(comment.text, list()).append(comment.timestamp_a)
for comment_string, comment_timestamps in comment_string_to_timestamps.items():
yield neo.Event(times=np.array(comment_timestamps, dtype=np.float64), units=pq.second, name=comment_string,
copy=False)
Expand Down Expand Up @@ -147,7 +171,7 @@ def waveformstream_to_analogsignals(self, stream: Stream, tolerance: float = 1e-
"""
if stream.stream_type != StreamType.Waveform:
raise ValueError(f"StreamType.Waveform required for this operation, not {stream.stream_type.name}")
for segment_group in self._group_recordingsegments((self.pages[pid] for pid in stream.page_ids),
for segment_group in self._group_recordingsegments((self.get_waveformpage(pid) for pid in stream.page_ids),
tolerance=tolerance):
continuous = np.concatenate(list(segment.values for segment in segment_group)).ravel()
yield neo.AnalogSignal(continuous, pq.volt,
Expand Down

0 comments on commit 10c2671

Please sign in to comment.