From 10c2671dbb36d9111f9e351cb1c328d8e5792672 Mon Sep 17 00:00:00 2001 From: Peter Konradi Date: Fri, 10 Mar 2023 13:46:58 +0100 Subject: [PATCH] introduce some helper functions for typed get of pages --- pydapsys/neo_convert/abstract_converter.py | 40 +++++++++++++++++----- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/pydapsys/neo_convert/abstract_converter.py b/pydapsys/neo_convert/abstract_converter.py index fd1f269..f7e5a06 100644 --- a/pydapsys/neo_convert/abstract_converter.py +++ b/pydapsys/neo_convert/abstract_converter.py @@ -6,7 +6,7 @@ import numpy.typing as npt import quantities as pq -from pydapsys.page import DataPage, WaveformPage +from pydapsys.page import DataPage, WaveformPage, TextPage, PageType from pydapsys.toc.entry import Root, Stream, StreamType from pydapsys.util.floats import float_comp @@ -33,6 +33,30 @@ def to_neo(self) -> neo.Block: """ ... + def _get_datapage_typechecked(self, pid: int, ptype: PageType) -> DataPage: + """ + gets the page with the given id and checks if it is of the requested type. + Throws an exception if the type doesn't match + """ + page = self.pages[pid] + if page.type != ptype: + raise Exception(f"page {pid} is not of type {ptype.Text}, but {page.type.Text}") + return page + + def get_textpage(self, pid: int) -> TextPage: + """ + Gets the page with the given id and checks if it is a text page. If it isn't, an exception is thrown. + Primarily serves as a convenience function for type hinting. + """ + return self._get_datapage_typechecked(pid, PageType.Text) + + def get_waveformpage(self, pid: int) -> WaveformPage: + """ + Gets the page with the given id and checks if it is a waveform page. If it isn't, an exception is thrown. + Primarily serves as a convenience function for type hinting. + """ + return self._get_datapage_typechecked(pid, PageType.Waveform) + def _pageids_to_event(self, page_ids: Union[Sequence[int], npt.NDArray[np.uint32]], name: str = "") -> neo.Event: """Converts data from a sequence (or numpy array) of page ids to a neo event. The labels will be taken from the page text and the event times from the first timestamp (timestamp_a) @@ -42,9 +66,9 @@ def _pageids_to_event(self, page_ids: Union[Sequence[int], npt.NDArray[np.uint32 """ times = np.empty(len(page_ids), dtype=np.float64) comments = np.empty(len(page_ids), dtype=str) - for i in range(len(page_ids)): - times[i] = self.pages[page_ids[i]].timestamp_a - comments[i] = self.pages[page_ids[i]].comment + for i, page in enumerate(self.get_textpage(pid) for pid in page_ids): + times[i] = page.timestamp_a + comments[i] = page.text return neo.Event(times=times, labels=comments, units=pq.second, name=name, copy=False) def textstream_to_event(self, stream: Stream, name: Optional[str] = None) -> neo.Event: @@ -73,7 +97,7 @@ def _pageids_to_spiketrain(self, page_ids: Union[Sequence[int], npt.NDArray[np.u :return: A spike train build from the comment pages """ return neo.SpikeTrain( - times=np.fromiter((comment.timestamp_a for comment in (self.pages[pid] for pid in page_ids)), + times=np.fromiter((comment.timestamp_a for comment in (self.get_textpage(pid) for pid in page_ids)), dtype=np.float64, count=len(page_ids)), name=name, units=pq.second, t_stop=t_stop, copy=False) @@ -100,8 +124,8 @@ def _pageids_to_events_by_comment_text(self, page_ids: Union[Sequence[int], npt. :return: An iterable of neo events """ comment_string_to_timestamps: Dict[str, List[float]] = dict() - for comment in (self.pages[pid] for pid in page_ids): - comment_string_to_timestamps.setdefault(comment.comment, list()).append(comment.timestamp_a) + for comment in (self.get_textpage(pid) for pid in page_ids): + comment_string_to_timestamps.setdefault(comment.text, list()).append(comment.timestamp_a) for comment_string, comment_timestamps in comment_string_to_timestamps.items(): yield neo.Event(times=np.array(comment_timestamps, dtype=np.float64), units=pq.second, name=comment_string, copy=False) @@ -147,7 +171,7 @@ def waveformstream_to_analogsignals(self, stream: Stream, tolerance: float = 1e- """ if stream.stream_type != StreamType.Waveform: raise ValueError(f"StreamType.Waveform required for this operation, not {stream.stream_type.name}") - for segment_group in self._group_recordingsegments((self.pages[pid] for pid in stream.page_ids), + for segment_group in self._group_recordingsegments((self.get_waveformpage(pid) for pid in stream.page_ids), tolerance=tolerance): continuous = np.concatenate(list(segment.values for segment in segment_group)).ravel() yield neo.AnalogSignal(continuous, pq.volt,