From bedb8797b5179e1757edc3d225c20cd724112989 Mon Sep 17 00:00:00 2001 From: Valerie Tsai <87097162+vtsai881@users.noreply.github.com> Date: Sun, 3 Dec 2023 15:55:59 -0500 Subject: [PATCH 1/6] Update commands.py added all_frames variable to exportAnalysisFile() and exportCSVFile() command definitions (ln 329, 331). added all_frames conditional to ExportAnalysisFile(command) definition --- sleap/gui/commands.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 90f40397e..4e8f0670d 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -326,13 +326,13 @@ def saveProjectAs(self): """Show gui to save project as a new file.""" self.execute(SaveProjectAs) - def exportAnalysisFile(self, all_videos: bool = False): + def exportAnalysisFile(self, all_videos: bool = False, all_frames: bool = False): """Shows gui for exporting analysis h5 file.""" - self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False) + self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=False) - def exportCSVFile(self, all_videos: bool = False): + def exportCSVFile(self, all_videos: bool = False, all_frames: bool = False): """Shows gui for exporting analysis csv file.""" - self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True) + self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=True) def exportNWB(self): """Show gui for exporting nwb file.""" @@ -1142,13 +1142,24 @@ def do_action(cls, context: CommandContext, params: dict): adaptor = NixAdaptor else: adaptor = SleapAnalysisAdaptor + + if params['all_frames']: adaptor.write( filename=output_path, + all_frames=True, source_object=context.labels, source_path=context.state["filename"], video=video, - ) - + ) + else: + adaptor.write( + filename=output_path, + all_frames=False, + source_object=context.labels, + source_path=context.state["filename"], + video=video, + ) + @staticmethod def ask(context: CommandContext, params: dict) -> bool: def ask_for_filename(default_name: str, csv: bool) -> str: From 8635ec06f87972e990ced78f620accd55c6f2511 Mon Sep 17 00:00:00 2001 From: Valerie Tsai <87097162+vtsai881@users.noreply.github.com> Date: Sun, 3 Dec 2023 15:59:17 -0500 Subject: [PATCH 2/6] Update app.py added dropdown menu button options for exporting csvs with all frames (ln 495-521) --- sleap/gui/app.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 41d696f0c..039cd931e 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -496,15 +496,28 @@ def add_submenu_choices(menu, title, options, key): add_menu_item( export_csv_menu, "export_csv_current", - "Current Video...", + "Current Video (only frames with tracking)...", self.commands.exportCSVFile, ) add_menu_item( export_csv_menu, "export_csv_all", - "All Videos...", + "All Videos (only frames with tracking)...", lambda: self.commands.exportCSVFile(all_videos=True), ) + + fileMenu.addSeparator() + add_menu_item( + export_csv_menu, + "export_csv_current", + "Current Video (all frames)...", + lambda: self.commands.exportCSVFile(all_frames=True), + ) + add_menu_item( + export_csv_menu, + "export_csv_all", + "All Videos (all frames)...", + lambda: self.commands.exportCSVFile(all_frames=True, all_videos=True), add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB) From 51f1d2196f4bac650ec58282e2473b0e5eb5a960 Mon Sep 17 00:00:00 2001 From: Valerie Tsai <87097162+vtsai881@users.noreply.github.com> Date: Sun, 10 Dec 2023 18:24:24 -0800 Subject: [PATCH 3/6] Update app.py --- sleap/gui/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 039cd931e..6cfe242b3 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -509,13 +509,13 @@ def add_submenu_choices(menu, title, options, key): fileMenu.addSeparator() add_menu_item( export_csv_menu, - "export_csv_current", + "export_csv_current_all_frames", "Current Video (all frames)...", lambda: self.commands.exportCSVFile(all_frames=True), ) add_menu_item( export_csv_menu, - "export_csv_all", + "export_csv_all_frames", "All Videos (all frames)...", lambda: self.commands.exportCSVFile(all_frames=True, all_videos=True), From 1f8cb4702b5585d70f9a61826d350fdc661e4179 Mon Sep 17 00:00:00 2001 From: vtsai881 Date: Mon, 11 Dec 2023 12:43:10 -0500 Subject: [PATCH 4/6] reformat --- sleap/gui/app.py | 18 +- sleap/gui/commands.py | 2 +- sleap/info/write_tracking_h5.py | 12 +- sleap/io/format/csv.py | 3 + sleap/io/format/nix.py | 484 +++--------------------------- sleap/io/format/sleap_analysis.py | 1 + 6 files changed, 78 insertions(+), 442 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 41d696f0c..47c3890c1 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -496,16 +496,30 @@ def add_submenu_choices(menu, title, options, key): add_menu_item( export_csv_menu, "export_csv_current", - "Current Video...", + "Current Video (only tracked frames)...", self.commands.exportCSVFile, ) add_menu_item( export_csv_menu, "export_csv_all", - "All Videos...", + "All Videos (only tracked frames)...", lambda: self.commands.exportCSVFile(all_videos=True), ) + export_csv_menu.addSeparator() + add_menu_item( + export_csv_menu, + "export_csv_current_all_frames", + "Current Video (all frames)...", + self.commands.exportCSVFile(all_frames=True), + ) + add_menu_item( + export_csv_menu, + "export_csv_all_all_frames", + "All Videos (all frames)...", + lambda: self.commands.exportCSVFile(all_frames=True), + ) + add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB) fileMenu.addSeparator() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 4e8f0670d..3924627bb 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -1146,7 +1146,7 @@ def do_action(cls, context: CommandContext, params: dict): if params['all_frames']: adaptor.write( filename=output_path, - all_frames=True, + all_frames=params.get('all_frames', False), source_object=context.labels, source_path=context.state["filename"], video=video, diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 2b714eeb5..a31f1f510 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -348,14 +348,22 @@ def write_csv_file(output_path, data_dict): tracks.append(detection) tracks = pd.DataFrame(tracks) - tracks.to_csv(output_path, index=False) + all_frames = globals().get('all_frames', False) + + if all_frames: + tracks = tracks.set_index('frame_idx') + tracks = tracks.reindex(range(0, len(data_dict['track_occupancy'])), fill_value=np.nan) + tracks = tracks.reset_index(drop=False) + tracks.to_csv(output_path, index=False) + else: + tracks.to_csv(output_path, index=False) def main( labels: Labels, output_path: str, labels_path: str = None, - all_frames: bool = True, + all_frames: bool = False, video: Video = None, csv: bool = False, ): diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py index 4640ee117..915370f03 100644 --- a/sleap/io/format/csv.py +++ b/sleap/io/format/csv.py @@ -45,6 +45,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, + all_frames: bool: False, video: Video = None, ): """Writes csv file for :py:class:`Labels` `source_object`. @@ -53,6 +54,8 @@ def write( filename: The filename for the output file. source_object: The :py:class:`Labels` from which to get data from. source_path: Path for the labels object + all_frames: A boolean flag to determine whether to include all frames + or only those with tracking data in the export. video: The :py:class:`Video` from which toget data from. If no `video` is specified, then the first video in `source_object` videos list will be used. If there are no :py:class:`Labeled Frame`s in the `video`, then no diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 4c39ec8b6..915370f03 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -1,463 +1,73 @@ -import numpy as np -import nixio as nix +"""Adaptor for writing SLEAP analysis as csv.""" -from pathlib import Path -from typing import Dict, List, Optional, cast -from sleap.instance import Track +from sleap.io import format -from sleap.io.format.adaptor import Adaptor, SleapObjectType -from sleap.io.format.filehandle import FileHandle -from sleap.io.dataset import Labels -from sleap.io.video import Video -from sleap.skeleton import Node, Skeleton +from sleap import Labels, Video +from typing import Optional, Callable, List, Text, Union -class NixAdaptor(Adaptor): - """Adaptor class for export of tracking analysis results to the generic - [NIX](https://github.com/g-node/nix) format. - NIX defines a generic data model for scientific data that combines data and data - annotations within the same container. The written files are hdf5 files that can - be read with any hdf5 library but follow the entity definitions of the NIX data - model. For reading nix-files with python install the nixio low-level library - ```pip install nixio``` or use the high-level api - [nixtrack](https://github.com/bendalab/nixtrack). +class CSVAdaptor(format.adaptor.Adaptor): + FORMAT_ID = 1.0 - So far the adaptor exports the tracked positions for each node of each instance, - the track and skeleton information along with the respective scores and the - centroid. Additionally, the video information is exported as metadata. - For more information on the mapping from sleap to nix see the docs on - [nixtrack](https://github.com/bendalab/nixtrack) (work in progress). - The adaptor uses a chunked writing approach which avoids numpy out of memory - exceptions when exporting large datasets. - - author: Jan Grewe (jan.grewe@g-node.org) - """ + # 1.0 initial implementation @property - def default_ext(self): - return "nix" + def handles(self): + return format.adaptor.SleapObjectType.labels @property - def all_exts(self) -> List[str]: - return [self.default_ext] + def default_ext(self): + return "csv" @property - def handles(self): - return SleapObjectType.misc + def all_exts(self): + return ["csv", "xlsx"] @property - def name(self) -> str: - """Human-reading name of the file format""" - return ( - "NIX file flavoured for animal tracking data https://github.com/g-node/nix" - ) + def name(self): + return "CSV" - @classmethod - def can_read_file(cls, file: FileHandle) -> bool: - """Returns whether this adaptor can read this file.""" + def can_read_file(self, file: format.filehandle.FileHandle): return False - def can_write_filename(self, filename: str) -> bool: - """Returns whether this adaptor can write format of this filename.""" - return filename.endswith(tuple(self.all_exts)) + def can_write_filename(self, filename: str): + return self.does_match_ext(filename) - @classmethod - def does_read(cls) -> bool: - """Returns whether this adaptor supports reading.""" + def does_read(self) -> bool: return False - @classmethod - def does_write(cls) -> bool: - """Returns whether this adaptor supports writing.""" + def does_write(self) -> bool: return True - @classmethod - def read(cls, file: FileHandle) -> object: - """Reads the file and returns the appropriate deserialized object.""" - raise NotImplementedError("NixAdaptor does not support reading.") - - @classmethod - def __check_video(cls, labels: Labels, video: Optional[Video]): - if (video is None) and (len(labels.videos) == 0): - raise ValueError( - f"There are no videos in this project. " - "No analysis file will be be written." - ) - if video is not None: - if video not in labels.videos: - raise ValueError( - f"Specified video {video} is not part of this project. " - "Skipping the analysis file for this video." - ) - if len(labels.get(video)) == 0: - raise ValueError( - f"No labeled frames in {video.backend.filename}. " - "Skipping the analysis file for this video." - ) - @classmethod def write( cls, filename: str, - source_object: object, - source_path: Optional[str] = None, - video: Optional[Video] = None, + source_object: Labels, + source_path: str = None, + all_frames: bool: False, + video: Video = None, ): - """Writes the object to a file.""" - source_object = cast(Labels, source_object) - - cls.__check_video(source_object, video) - - def create_file(filename: str, project: Optional[str], video: Video): - print(f"Creating nix file...", end="\t") - nf = nix.File.open(filename, nix.FileMode.Overwrite) - try: - s = nf.create_section("TrackingAnalysis", "nix.tracking.metadata") - s["version"] = "0.1.0" - s["format"] = "nix.tracking" - s["definitions"] = "https://github.com/bendalab/nixtrack" - s["writer"] = str(cls)[8:-2] - if project is not None: - s["project"] = project - - name = Path(video.backend.filename).name - b = nf.create_block(name, "nix.tracking_results") - - # add video metadata, if exists - src = b.create_source(name, "nix.tracking.source.video") - sec = src.file.create_section( - name, "nix.tracking.source.video.metadata" - ) - sec["filename"] = video.backend.filename - sec["fps"] = getattr(video.backend, "fps", 0.0) - sec.props["fps"].unit = "Hz" - sec["frames"] = video.num_frames - sec["grayscale"] = getattr(video.backend, "grayscale", None) - sec["height"] = video.backend.height - sec["width"] = video.backend.width - src.metadata = sec - except Exception as e: - nf.close() - raise e - - print("done") - return nf - - def track_map(source: Labels) -> Dict[Track, int]: - track_map: Dict[Track, int] = {} - for track in source.tracks: - if track in track_map: - continue - track_map[track] = len(track_map) - return track_map - - def skeleton_map(source: Labels) -> Dict[Skeleton, int]: - skel_map: Dict[Skeleton, int] = {} - for skeleton in source.skeletons: - if skeleton in skel_map: - continue - skel_map[skeleton] = len(skel_map) - return skel_map - - def node_map(source: Labels) -> Dict[Node, int]: - n_map: Dict[Node, int] = {} - for node in source.nodes: - if node in n_map: - continue - n_map[node] = len(n_map) - return n_map - - def create_feature_array(name, type, block, frame_index_array, shape, dtype): - array = block.create_data_array(name, type, dtype=dtype, shape=shape) - rd = array.append_range_dimension() - rd.link_data_array(frame_index_array, [-1]) - return array - - def create_positions_array( - name, type, block, frame_index_array, node_names, shape, dtype - ): - array = block.create_data_array( - name, type, dtype=dtype, shape=shape, label="pixel" - ) - rd = array.append_range_dimension() - rd.link_data_array(frame_index_array, [-1]) - array.append_set_dimension(["x", "y"]) - array.append_set_dimension(node_names) - return array - - def chunked_write( - instances, - frameid_array, - positions_array, - track_array, - skeleton_array, - pointscore_array, - instancescore_array, - trackingscore_array, - centroid_array, - track_map, - node_map, - skeleton_map, - chunksize=10000, - ): - data_written = 0 - indices = np.zeros(chunksize, dtype=int) - track = np.zeros_like(indices) - skeleton = np.zeros_like(indices) - instscore = np.zeros_like(indices, dtype=float) - positions = np.zeros((chunksize, 2, len(node_map.keys())), dtype=float) - centroids = np.zeros((chunksize, 2), dtype=float) - trackscore = np.zeros_like(instscore) - pointscore = np.zeros((chunksize, len(node_map.keys())), dtype=float) - dflt_pointscore = [0.0 for n in range(len(node_map.keys()))] - - while data_written < len(instances): - print(".", end="") - start = data_written - end = ( - len(instances) - if start + chunksize >= len(instances) - else start + chunksize - ) - for i in range(start, end): - inst = instances[i] - index = i - start - indices[index] = inst.frame_idx - if inst.track is not None: - track[index] = track_map[inst.track] - else: - track[index] = -1 - - skeleton[index] = skeleton_map[inst.skeleton] - - all_nodes = set([n.name for n in inst.nodes]) - used_nodes = set([n.name for n in node_map.keys()]) - missing_nodes = all_nodes.difference(used_nodes) - for n, p in zip(inst.nodes, inst.points): - positions[index, :, node_map[n]] = np.array([p.x, p.y]) - for m in missing_nodes: - positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) - - centroids[index, :] = inst.centroid - if hasattr(inst, "score"): - instscore[index] = inst.score - trackscore[index] = inst.tracking_score - pointscore[index, :] = inst.scores - else: - instscore[index] = 0.0 - trackscore[index] = 0.0 - pointscore[index, :] = dflt_pointscore - - frameid_array[start:end] = indices[: end - start] - track_array[start:end] = track[: end - start] - positions_array[start:end, :, :] = positions[: end - start, :, :] - centroid_array[start:end, :] = centroids[: end - start, :] - skeleton_array[start:end] = skeleton[: end - start] - pointscore_array[start:end] = pointscore[: end - start] - instancescore_array[start:end] = instscore[: end - start] - trackingscore_array[start:end] = trackscore[: end - start] - data_written += end - start - - def write_data(block, source: Labels, video: Video): - instances = [ - instance - for instance in source.instances(video=video) - if instance.frame_idx is not None - ] - instances = sorted(instances, key=lambda i: i.frame_idx) - nodes = node_map(source) - tracks = track_map(source) - skeletons = skeleton_map(source) - positions_shape = (len(instances), 2, len(nodes)) - - frameid_array = block.create_data_array( - "frame", - "nix.tracking.instance_frameidx", - label="frame index", - shape=(len(instances),), - dtype=nix.DataType.Int64, - ) - frameid_array.append_range_dimension_using_self() - - positions_array = create_positions_array( - "position", - "nix.tracking.instance_position", - block, - frameid_array, - [node.name for node in nodes.keys()], - positions_shape, - nix.DataType.Float, - ) - - track_array = create_feature_array( - "track", - "nix.tracking.instance_track", - block, - frameid_array, - shape=(len(instances),), - dtype=nix.DataType.Int64, - ) - - skeleton_array = create_feature_array( - "skeleton", - "nix.tracking.instance_skeleton", - block, - frameid_array, - (len(instances),), - nix.DataType.Int64, - ) - - point_score = create_feature_array( - "node score", - "nix.tracking.nodes_score", - block, - frameid_array, - (len(instances), len(nodes)), - nix.DataType.Float, - ) - point_score.append_set_dimension([node.name for node in nodes.keys()]) - - centroid_array = create_feature_array( - "centroid", - "nix.tracking.centroid_position", - block, - frameid_array, - (len(instances), 2), - nix.DataType.Float, - ) - - centroid_array.append_set_dimension(["x", "y"]) - instance_score = create_feature_array( - "instance score", - "nix.tracking.instance_score", - block, - frameid_array, - (len(instances),), - nix.DataType.Float, - ) - - tracking_score = create_feature_array( - "tracking score", - "nix.tracking.tack_score", - block, - frameid_array, - (len(instances),), - nix.DataType.Float, - ) - - # bind all together using a nix.MultiTag - mtag = block.create_multi_tag( - "tracking results", "nix.tracking.results", positions=frameid_array - ) - mtag.references.append(positions_array) - mtag.create_feature(track_array, nix.LinkType.Indexed) - mtag.create_feature(skeleton_array, nix.LinkType.Indexed) - mtag.create_feature(point_score, nix.LinkType.Indexed) - mtag.create_feature(instance_score, nix.LinkType.Indexed) - mtag.create_feature(tracking_score, nix.LinkType.Indexed) - mtag.create_feature(centroid_array, nix.LinkType.Indexed) - - sm = block.create_data_frame( - "skeleton map", - "nix.tracking.skeleton_map", - col_names=["name", "index"], - col_dtypes=[nix.DataType.String, nix.DataType.Int8], - ) - table_data = [] - for track in skeletons.keys(): - table_data.append((track.name, skeletons[track])) - sm.append_rows(table_data) - - nm = block.create_data_frame( - "node map", - "nix.tracking.node_map", - col_names=["name", "weight", "index", "skeleton"], - col_dtypes=[ - nix.DataType.String, - nix.DataType.Float, - nix.DataType.Int8, - nix.DataType.Int8, - ], - ) - table_data = [] - for node in nodes.keys(): - skel_index = -1 # if node is not assigned to a skeleton - for track in skeletons: - if node in track.nodes: - skel_index = skeletons[track] - break - table_data.append((node.name, node.weight, nodes[node], skel_index)) - nm.append_rows(table_data) - - tm = block.create_data_frame( - "track map", - "nix.tracking.track_map", - col_names=["name", "spawned_on", "index"], - col_dtypes=[nix.DataType.String, nix.DataType.Int64, nix.DataType.Int8], - ) - table_data = [("none", -1, -1)] # default for user-labeled instances - for track in tracks.keys(): - table_data.append((track.name, track.spawned_on, tracks[track])) - tm.append_rows(table_data) - - # Print shape info - data_dict = { - "instances": instances, - "frameid_array": frameid_array, - "positions_array": positions_array, - "track_array": track_array, - "skeleton_array": skeleton_array, - "point_score": point_score, - "instance_score": instance_score, - "tracking_score": tracking_score, - "centroid_array": centroid_array, - "tracks": tracks, - "nodes": nodes, - "skeletons": skeletons, - } - for key, val in data_dict.items(): - print(f"\t{key}:", end=" ") - if hasattr(val, "shape"): - print(f"{val.shape}") - else: - print(f"{len(val)}") - - # Print labels/video info - print( - f"\tlabels path: {source_path}\n" - f"\tvideo path: {video.backend.filename}\n" - f"\tvideo index = {source_object.videos.index(video)}" - ) - - print(f"Writing to NIX file...") - chunked_write( - instances, - frameid_array, - positions_array, - track_array, - skeleton_array, - point_score, - instance_score, - tracking_score, - centroid_array, - tracks, - nodes, - skeletons, - ) - print(f"done") - - print(f"\nExporting to NIX analysis file...") - if video is None: - video = source_object.videos[0] - print(f"No video specified, exporting the first one...") - - nix_file = None - try: - nix_file = create_file(filename, source_path, video) - write_data(nix_file.blocks[0], source_object, video) - print(f"Saved as {filename}") - except Exception as e: - print(f"\n\tWriting failed with following error:\n{e}!") - finally: - if nix_file is not None: - nix_file.close() + """Writes csv file for :py:class:`Labels` `source_object`. + + Args: + filename: The filename for the output file. + source_object: The :py:class:`Labels` from which to get data from. + source_path: Path for the labels object + all_frames: A boolean flag to determine whether to include all frames + or only those with tracking data in the export. + video: The :py:class:`Video` from which toget data from. If no `video` is + specified, then the first video in `source_object` videos list will be + used. If there are no :py:class:`Labeled Frame`s in the `video`, then no + analysis file will be written. + """ + from sleap.info.write_tracking_h5 import main as write_analysis + + write_analysis( + labels=source_object, + output_path=filename, + labels_path=source_path, + all_frames=True, + video=video, + csv=True, + ) diff --git a/sleap/io/format/sleap_analysis.py b/sleap/io/format/sleap_analysis.py index cc6eedb6f..41e2f2397 100644 --- a/sleap/io/format/sleap_analysis.py +++ b/sleap/io/format/sleap_analysis.py @@ -129,6 +129,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, + all_frames: bool = False, video: Video = None, ): """Writes analysis file for :py:class:`Labels` `source_object`. From a9cb4960165f8e3a48b6b8240d1dc0e0e6735027 Mon Sep 17 00:00:00 2001 From: vtsai881 Date: Mon, 11 Dec 2023 14:03:21 -0500 Subject: [PATCH 5/6] Minor edits --- sleap/gui/app.py | 13 - sleap/io/format/csv.py | 4 +- sleap/io/format/nix.py | 485 +++++++++++++++++++++++++++--- sleap/io/format/sleap_analysis.py | 2 +- 4 files changed, 441 insertions(+), 63 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9ee918004..47c3890c1 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -505,19 +505,6 @@ def add_submenu_choices(menu, title, options, key): "All Videos (only tracked frames)...", lambda: self.commands.exportCSVFile(all_videos=True), ) - - fileMenu.addSeparator() - add_menu_item( - export_csv_menu, - "export_csv_current_all_frames", - "Current Video (all frames)...", - lambda: self.commands.exportCSVFile(all_frames=True), - ) - add_menu_item( - export_csv_menu, - "export_csv_all_frames", - "All Videos (all frames)...", - lambda: self.commands.exportCSVFile(all_frames=True, all_videos=True), export_csv_menu.addSeparator() add_menu_item( diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py index 915370f03..cf8e78145 100644 --- a/sleap/io/format/csv.py +++ b/sleap/io/format/csv.py @@ -45,7 +45,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, - all_frames: bool: False, + all_frames: bool= False, video: Video = None, ): """Writes csv file for :py:class:`Labels` `source_object`. @@ -67,7 +67,7 @@ def write( labels=source_object, output_path=filename, labels_path=source_path, - all_frames=True, + all_frames=all_frames, video=video, csv=True, ) diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 915370f03..73ab65b6e 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -1,73 +1,464 @@ -"""Adaptor for writing SLEAP analysis as csv.""" +import numpy as np +import nixio as nix -from sleap.io import format +from pathlib import Path +from typing import Dict, List, Optional, cast +from sleap.instance import Track -from sleap import Labels, Video -from typing import Optional, Callable, List, Text, Union +from sleap.io.format.adaptor import Adaptor, SleapObjectType +from sleap.io.format.filehandle import FileHandle +from sleap.io.dataset import Labels +from sleap.io.video import Video +from sleap.skeleton import Node, Skeleton -class CSVAdaptor(format.adaptor.Adaptor): - FORMAT_ID = 1.0 +class NixAdaptor(Adaptor): + """Adaptor class for export of tracking analysis results to the generic + [NIX](https://github.com/g-node/nix) format. + NIX defines a generic data model for scientific data that combines data and data + annotations within the same container. The written files are hdf5 files that can + be read with any hdf5 library but follow the entity definitions of the NIX data + model. For reading nix-files with python install the nixio low-level library + ```pip install nixio``` or use the high-level api + [nixtrack](https://github.com/bendalab/nixtrack). - # 1.0 initial implementation + So far the adaptor exports the tracked positions for each node of each instance, + the track and skeleton information along with the respective scores and the + centroid. Additionally, the video information is exported as metadata. + For more information on the mapping from sleap to nix see the docs on + [nixtrack](https://github.com/bendalab/nixtrack) (work in progress). + The adaptor uses a chunked writing approach which avoids numpy out of memory + exceptions when exporting large datasets. - @property - def handles(self): - return format.adaptor.SleapObjectType.labels + author: Jan Grewe (jan.grewe@g-node.org) + """ @property def default_ext(self): - return "csv" + return "nix" + + @property + def all_exts(self) -> List[str]: + return [self.default_ext] @property - def all_exts(self): - return ["csv", "xlsx"] + def handles(self): + return SleapObjectType.misc @property - def name(self): - return "CSV" + def name(self) -> str: + """Human-reading name of the file format""" + return ( + "NIX file flavoured for animal tracking data https://github.com/g-node/nix" + ) - def can_read_file(self, file: format.filehandle.FileHandle): + @classmethod + def can_read_file(cls, file: FileHandle) -> bool: + """Returns whether this adaptor can read this file.""" return False - def can_write_filename(self, filename: str): - return self.does_match_ext(filename) + def can_write_filename(self, filename: str) -> bool: + """Returns whether this adaptor can write format of this filename.""" + return filename.endswith(tuple(self.all_exts)) - def does_read(self) -> bool: + @classmethod + def does_read(cls) -> bool: + """Returns whether this adaptor supports reading.""" return False - def does_write(self) -> bool: + @classmethod + def does_write(cls) -> bool: + """Returns whether this adaptor supports writing.""" return True + @classmethod + def read(cls, file: FileHandle) -> object: + """Reads the file and returns the appropriate deserialized object.""" + raise NotImplementedError("NixAdaptor does not support reading.") + + @classmethod + def __check_video(cls, labels: Labels, video: Optional[Video]): + if (video is None) and (len(labels.videos) == 0): + raise ValueError( + f"There are no videos in this project. " + "No analysis file will be be written." + ) + if video is not None: + if video not in labels.videos: + raise ValueError( + f"Specified video {video} is not part of this project. " + "Skipping the analysis file for this video." + ) + if len(labels.get(video)) == 0: + raise ValueError( + f"No labeled frames in {video.backend.filename}. " + "Skipping the analysis file for this video." + ) + @classmethod def write( cls, filename: str, - source_object: Labels, - source_path: str = None, - all_frames: bool: False, - video: Video = None, + source_object: object, + source_path: Optional[str] = None, + all_frames: bool = False, + video: Optional[Video] = None, ): - """Writes csv file for :py:class:`Labels` `source_object`. - - Args: - filename: The filename for the output file. - source_object: The :py:class:`Labels` from which to get data from. - source_path: Path for the labels object - all_frames: A boolean flag to determine whether to include all frames - or only those with tracking data in the export. - video: The :py:class:`Video` from which toget data from. If no `video` is - specified, then the first video in `source_object` videos list will be - used. If there are no :py:class:`Labeled Frame`s in the `video`, then no - analysis file will be written. - """ - from sleap.info.write_tracking_h5 import main as write_analysis - - write_analysis( - labels=source_object, - output_path=filename, - labels_path=source_path, - all_frames=True, - video=video, - csv=True, - ) + """Writes the object to a file.""" + source_object = cast(Labels, source_object) + + cls.__check_video(source_object, video) + + def create_file(filename: str, project: Optional[str], video: Video): + print(f"Creating nix file...", end="\t") + nf = nix.File.open(filename, nix.FileMode.Overwrite) + try: + s = nf.create_section("TrackingAnalysis", "nix.tracking.metadata") + s["version"] = "0.1.0" + s["format"] = "nix.tracking" + s["definitions"] = "https://github.com/bendalab/nixtrack" + s["writer"] = str(cls)[8:-2] + if project is not None: + s["project"] = project + + name = Path(video.backend.filename).name + b = nf.create_block(name, "nix.tracking_results") + + # add video metadata, if exists + src = b.create_source(name, "nix.tracking.source.video") + sec = src.file.create_section( + name, "nix.tracking.source.video.metadata" + ) + sec["filename"] = video.backend.filename + sec["fps"] = getattr(video.backend, "fps", 0.0) + sec.props["fps"].unit = "Hz" + sec["frames"] = video.num_frames + sec["grayscale"] = getattr(video.backend, "grayscale", None) + sec["height"] = video.backend.height + sec["width"] = video.backend.width + src.metadata = sec + except Exception as e: + nf.close() + raise e + + print("done") + return nf + + def track_map(source: Labels) -> Dict[Track, int]: + track_map: Dict[Track, int] = {} + for track in source.tracks: + if track in track_map: + continue + track_map[track] = len(track_map) + return track_map + + def skeleton_map(source: Labels) -> Dict[Skeleton, int]: + skel_map: Dict[Skeleton, int] = {} + for skeleton in source.skeletons: + if skeleton in skel_map: + continue + skel_map[skeleton] = len(skel_map) + return skel_map + + def node_map(source: Labels) -> Dict[Node, int]: + n_map: Dict[Node, int] = {} + for node in source.nodes: + if node in n_map: + continue + n_map[node] = len(n_map) + return n_map + + def create_feature_array(name, type, block, frame_index_array, shape, dtype): + array = block.create_data_array(name, type, dtype=dtype, shape=shape) + rd = array.append_range_dimension() + rd.link_data_array(frame_index_array, [-1]) + return array + + def create_positions_array( + name, type, block, frame_index_array, node_names, shape, dtype + ): + array = block.create_data_array( + name, type, dtype=dtype, shape=shape, label="pixel" + ) + rd = array.append_range_dimension() + rd.link_data_array(frame_index_array, [-1]) + array.append_set_dimension(["x", "y"]) + array.append_set_dimension(node_names) + return array + + def chunked_write( + instances, + frameid_array, + positions_array, + track_array, + skeleton_array, + pointscore_array, + instancescore_array, + trackingscore_array, + centroid_array, + track_map, + node_map, + skeleton_map, + chunksize=10000, + ): + data_written = 0 + indices = np.zeros(chunksize, dtype=int) + track = np.zeros_like(indices) + skeleton = np.zeros_like(indices) + instscore = np.zeros_like(indices, dtype=float) + positions = np.zeros((chunksize, 2, len(node_map.keys())), dtype=float) + centroids = np.zeros((chunksize, 2), dtype=float) + trackscore = np.zeros_like(instscore) + pointscore = np.zeros((chunksize, len(node_map.keys())), dtype=float) + dflt_pointscore = [0.0 for n in range(len(node_map.keys()))] + + while data_written < len(instances): + print(".", end="") + start = data_written + end = ( + len(instances) + if start + chunksize >= len(instances) + else start + chunksize + ) + for i in range(start, end): + inst = instances[i] + index = i - start + indices[index] = inst.frame_idx + if inst.track is not None: + track[index] = track_map[inst.track] + else: + track[index] = -1 + + skeleton[index] = skeleton_map[inst.skeleton] + + all_nodes = set([n.name for n in inst.nodes]) + used_nodes = set([n.name for n in node_map.keys()]) + missing_nodes = all_nodes.difference(used_nodes) + for n, p in zip(inst.nodes, inst.points): + positions[index, :, node_map[n]] = np.array([p.x, p.y]) + for m in missing_nodes: + positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) + + centroids[index, :] = inst.centroid + if hasattr(inst, "score"): + instscore[index] = inst.score + trackscore[index] = inst.tracking_score + pointscore[index, :] = inst.scores + else: + instscore[index] = 0.0 + trackscore[index] = 0.0 + pointscore[index, :] = dflt_pointscore + + frameid_array[start:end] = indices[: end - start] + track_array[start:end] = track[: end - start] + positions_array[start:end, :, :] = positions[: end - start, :, :] + centroid_array[start:end, :] = centroids[: end - start, :] + skeleton_array[start:end] = skeleton[: end - start] + pointscore_array[start:end] = pointscore[: end - start] + instancescore_array[start:end] = instscore[: end - start] + trackingscore_array[start:end] = trackscore[: end - start] + data_written += end - start + + def write_data(block, source: Labels, video: Video): + instances = [ + instance + for instance in source.instances(video=video) + if instance.frame_idx is not None + ] + instances = sorted(instances, key=lambda i: i.frame_idx) + nodes = node_map(source) + tracks = track_map(source) + skeletons = skeleton_map(source) + positions_shape = (len(instances), 2, len(nodes)) + + frameid_array = block.create_data_array( + "frame", + "nix.tracking.instance_frameidx", + label="frame index", + shape=(len(instances),), + dtype=nix.DataType.Int64, + ) + frameid_array.append_range_dimension_using_self() + + positions_array = create_positions_array( + "position", + "nix.tracking.instance_position", + block, + frameid_array, + [node.name for node in nodes.keys()], + positions_shape, + nix.DataType.Float, + ) + + track_array = create_feature_array( + "track", + "nix.tracking.instance_track", + block, + frameid_array, + shape=(len(instances),), + dtype=nix.DataType.Int64, + ) + + skeleton_array = create_feature_array( + "skeleton", + "nix.tracking.instance_skeleton", + block, + frameid_array, + (len(instances),), + nix.DataType.Int64, + ) + + point_score = create_feature_array( + "node score", + "nix.tracking.nodes_score", + block, + frameid_array, + (len(instances), len(nodes)), + nix.DataType.Float, + ) + point_score.append_set_dimension([node.name for node in nodes.keys()]) + + centroid_array = create_feature_array( + "centroid", + "nix.tracking.centroid_position", + block, + frameid_array, + (len(instances), 2), + nix.DataType.Float, + ) + + centroid_array.append_set_dimension(["x", "y"]) + instance_score = create_feature_array( + "instance score", + "nix.tracking.instance_score", + block, + frameid_array, + (len(instances),), + nix.DataType.Float, + ) + + tracking_score = create_feature_array( + "tracking score", + "nix.tracking.tack_score", + block, + frameid_array, + (len(instances),), + nix.DataType.Float, + ) + + # bind all together using a nix.MultiTag + mtag = block.create_multi_tag( + "tracking results", "nix.tracking.results", positions=frameid_array + ) + mtag.references.append(positions_array) + mtag.create_feature(track_array, nix.LinkType.Indexed) + mtag.create_feature(skeleton_array, nix.LinkType.Indexed) + mtag.create_feature(point_score, nix.LinkType.Indexed) + mtag.create_feature(instance_score, nix.LinkType.Indexed) + mtag.create_feature(tracking_score, nix.LinkType.Indexed) + mtag.create_feature(centroid_array, nix.LinkType.Indexed) + + sm = block.create_data_frame( + "skeleton map", + "nix.tracking.skeleton_map", + col_names=["name", "index"], + col_dtypes=[nix.DataType.String, nix.DataType.Int8], + ) + table_data = [] + for track in skeletons.keys(): + table_data.append((track.name, skeletons[track])) + sm.append_rows(table_data) + + nm = block.create_data_frame( + "node map", + "nix.tracking.node_map", + col_names=["name", "weight", "index", "skeleton"], + col_dtypes=[ + nix.DataType.String, + nix.DataType.Float, + nix.DataType.Int8, + nix.DataType.Int8, + ], + ) + table_data = [] + for node in nodes.keys(): + skel_index = -1 # if node is not assigned to a skeleton + for track in skeletons: + if node in track.nodes: + skel_index = skeletons[track] + break + table_data.append((node.name, node.weight, nodes[node], skel_index)) + nm.append_rows(table_data) + + tm = block.create_data_frame( + "track map", + "nix.tracking.track_map", + col_names=["name", "spawned_on", "index"], + col_dtypes=[nix.DataType.String, nix.DataType.Int64, nix.DataType.Int8], + ) + table_data = [("none", -1, -1)] # default for user-labeled instances + for track in tracks.keys(): + table_data.append((track.name, track.spawned_on, tracks[track])) + tm.append_rows(table_data) + + # Print shape info + data_dict = { + "instances": instances, + "frameid_array": frameid_array, + "positions_array": positions_array, + "track_array": track_array, + "skeleton_array": skeleton_array, + "point_score": point_score, + "instance_score": instance_score, + "tracking_score": tracking_score, + "centroid_array": centroid_array, + "tracks": tracks, + "nodes": nodes, + "skeletons": skeletons, + } + for key, val in data_dict.items(): + print(f"\t{key}:", end=" ") + if hasattr(val, "shape"): + print(f"{val.shape}") + else: + print(f"{len(val)}") + + # Print labels/video info + print( + f"\tlabels path: {source_path}\n" + f"\tvideo path: {video.backend.filename}\n" + f"\tvideo index = {source_object.videos.index(video)}" + ) + + print(f"Writing to NIX file...") + chunked_write( + instances, + frameid_array, + positions_array, + track_array, + skeleton_array, + point_score, + instance_score, + tracking_score, + centroid_array, + tracks, + nodes, + skeletons, + ) + print(f"done") + + print(f"\nExporting to NIX analysis file...") + if video is None: + video = source_object.videos[0] + print(f"No video specified, exporting the first one...") + + nix_file = None + try: + nix_file = create_file(filename, source_path, video) + write_data(nix_file.blocks[0], source_object, video) + print(f"Saved as {filename}") + except Exception as e: + print(f"\n\tWriting failed with following error:\n{e}!") + finally: + if nix_file is not None: + nix_file.close() \ No newline at end of file diff --git a/sleap/io/format/sleap_analysis.py b/sleap/io/format/sleap_analysis.py index 41e2f2397..bef5fd1b9 100644 --- a/sleap/io/format/sleap_analysis.py +++ b/sleap/io/format/sleap_analysis.py @@ -148,6 +148,6 @@ def write( labels=source_object, output_path=filename, labels_path=source_path, - all_frames=True, + all_frames=all_frames, video=video, ) From 6d93f171c802c99aa3f1a9695b1b927f656a983c Mon Sep 17 00:00:00 2001 From: vtsai881 Date: Thu, 14 Dec 2023 14:24:19 -0500 Subject: [PATCH 6/6] revisions revisions --- sleap/gui/app.py | 4 +-- sleap/gui/commands.py | 34 +++++++++---------- sleap/info/write_tracking_h5.py | 5 ++- sleap/io/format/csv.py | 9 +++-- tests/gui/test_commands.py | 60 +++++++++++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 27 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 47c3890c1..fb1439368 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -511,13 +511,13 @@ def add_submenu_choices(menu, title, options, key): export_csv_menu, "export_csv_current_all_frames", "Current Video (all frames)...", - self.commands.exportCSVFile(all_frames=True), + lambda: self.commands.exportCSVFile(all_frames=True), ) add_menu_item( export_csv_menu, "export_csv_all_all_frames", "All Videos (all frames)...", - lambda: self.commands.exportCSVFile(all_frames=True), + lambda: self.commands.exportCSVFile(all_videos=True, all_frames=True), ) add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 3924627bb..925952316 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -1143,23 +1143,23 @@ def do_action(cls, context: CommandContext, params: dict): else: adaptor = SleapAnalysisAdaptor - if params['all_frames']: - adaptor.write( - filename=output_path, - all_frames=params.get('all_frames', False), - source_object=context.labels, - source_path=context.state["filename"], - video=video, - ) - else: - adaptor.write( - filename=output_path, - all_frames=False, - source_object=context.labels, - source_path=context.state["filename"], - video=video, - ) - + if 'all_frames' in params and params['all_frames']: + adaptor.write( + filename=output_path, + all_frames=True, + source_object=context.labels, + source_path=context.state["filename"], + video=video, + ) + else: + adaptor.write( + filename=output_path, + all_frames=False, + source_object=context.labels, + source_path=context.state["filename"], + video=video, + ) + @staticmethod def ask(context: CommandContext, params: dict) -> bool: def ask_for_filename(default_name: str, csv: bool) -> str: diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index a31f1f510..3e960bbe6 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -287,7 +287,7 @@ def write_occupancy_file( print(f"Saved as {output_path}") -def write_csv_file(output_path, data_dict): +def write_csv_file(output_path, data_dict, all_frames): """Write CSV file with data from given dictionary. @@ -348,7 +348,6 @@ def write_csv_file(output_path, data_dict): tracks.append(detection) tracks = pd.DataFrame(tracks) - all_frames = globals().get('all_frames', False) if all_frames: tracks = tracks.set_index('frame_idx') @@ -443,7 +442,7 @@ def main( ) if csv: - write_csv_file(output_path, data_dict) + write_csv_file(output_path, data_dict, all_frames=all_frames) else: write_occupancy_file(output_path, data_dict, transpose=True) diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py index cf8e78145..17666f1e1 100644 --- a/sleap/io/format/csv.py +++ b/sleap/io/format/csv.py @@ -45,7 +45,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, - all_frames: bool= False, + all_frames: bool = False, video: Video = None, ): """Writes csv file for :py:class:`Labels` `source_object`. @@ -54,15 +54,14 @@ def write( filename: The filename for the output file. source_object: The :py:class:`Labels` from which to get data from. source_path: Path for the labels object - all_frames: A boolean flag to determine whether to include all frames - or only those with tracking data in the export. - video: The :py:class:`Video` from which toget data from. If no `video` is + all_frames: A boolean flag to determine whether to include all frames or + only those with tracking data in the export. + video: The :py:class:`Video` from which to get data from. If no `video` is specified, then the first video in `source_object` videos list will be used. If there are no :py:class:`Labeled Frame`s in the `video`, then no analysis file will be written. """ from sleap.info.write_tracking_h5 import main as write_analysis - write_analysis( labels=source_object, output_path=filename, diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 899b1f4a0..7c6b2fd36 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -279,6 +279,66 @@ def assert_videos_written(num_videos: int, labels_path: str = None): ExportAnalysisFile.do_action(context=context, params=params) assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + # Test with all_videos True and all_frames True + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Test with all_videos False and all_frames True + params = {"all_videos": False, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=1, labels_path=context.state["filename"]) + + # Test with all_videos False and all_frames False + params = {"all_videos": False, "all_frames": False, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=1, labels_path=context.state["filename"]) + + # Add labels path and test with all_videos True and all_frames True (single video) + context.state["filename"] = str(tmpdir.with_name("path.to.labels")) + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Add a video (no labels) and test with all_videos True and all_frames True + labels.add_video(small_robot_mp4_vid) + + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Test with videos with the same filename + (tmpdir / "session1").mkdir() + (tmpdir / "session2").mkdir() + shutil.copy( + centered_pair_predictions.video.backend.filename, + tmpdir / "session1" / "video.mp4", + ) + shutil.copy(small_robot_mp4_vid.backend.filename, tmpdir / "session2" / "video.mp4") + labels.videos[0].backend.filename = str(tmpdir / "session1" / "video.mp4") + labels.videos[1].backend.filename = str(tmpdir / "session2" / "video.mp4") + params = {"all_videos": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Remove all videos and test + all_videos = list(labels.videos) + for video in all_videos: + labels.remove_video(labels.videos[-1]) + + params = {"all_videos": True, "all_frames": True, "csv": csv} # Test with videos with the same filename (tmpdir / "session1").mkdir() (tmpdir / "session2").mkdir()