diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 41d696f0c..fb1439368 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -496,16 +496,30 @@ def add_submenu_choices(menu, title, options, key): add_menu_item( export_csv_menu, "export_csv_current", - "Current Video...", + "Current Video (only tracked frames)...", self.commands.exportCSVFile, ) add_menu_item( export_csv_menu, "export_csv_all", - "All Videos...", + "All Videos (only tracked frames)...", lambda: self.commands.exportCSVFile(all_videos=True), ) + export_csv_menu.addSeparator() + add_menu_item( + export_csv_menu, + "export_csv_current_all_frames", + "Current Video (all frames)...", + lambda: self.commands.exportCSVFile(all_frames=True), + ) + add_menu_item( + export_csv_menu, + "export_csv_all_all_frames", + "All Videos (all frames)...", + lambda: self.commands.exportCSVFile(all_videos=True, all_frames=True), + ) + add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB) fileMenu.addSeparator() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 90f40397e..925952316 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -326,13 +326,13 @@ def saveProjectAs(self): """Show gui to save project as a new file.""" self.execute(SaveProjectAs) - def exportAnalysisFile(self, all_videos: bool = False): + def exportAnalysisFile(self, all_videos: bool = False, all_frames: bool = False): """Shows gui for exporting analysis h5 file.""" - self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False) + self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=False) - def exportCSVFile(self, all_videos: bool = False): + def exportCSVFile(self, all_videos: bool = False, all_frames: bool = False): """Shows gui for exporting analysis csv file.""" - self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True) + self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=True) def exportNWB(self): """Show gui for exporting nwb file.""" @@ -1142,12 +1142,23 @@ def do_action(cls, context: CommandContext, params: dict): adaptor = NixAdaptor else: adaptor = SleapAnalysisAdaptor - adaptor.write( - filename=output_path, - source_object=context.labels, - source_path=context.state["filename"], - video=video, - ) + + if 'all_frames' in params and params['all_frames']: + adaptor.write( + filename=output_path, + all_frames=True, + source_object=context.labels, + source_path=context.state["filename"], + video=video, + ) + else: + adaptor.write( + filename=output_path, + all_frames=False, + source_object=context.labels, + source_path=context.state["filename"], + video=video, + ) @staticmethod def ask(context: CommandContext, params: dict) -> bool: diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 2b714eeb5..3e960bbe6 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -287,7 +287,7 @@ def write_occupancy_file( print(f"Saved as {output_path}") -def write_csv_file(output_path, data_dict): +def write_csv_file(output_path, data_dict, all_frames): """Write CSV file with data from given dictionary. @@ -348,14 +348,21 @@ def write_csv_file(output_path, data_dict): tracks.append(detection) tracks = pd.DataFrame(tracks) - tracks.to_csv(output_path, index=False) + + if all_frames: + tracks = tracks.set_index('frame_idx') + tracks = tracks.reindex(range(0, len(data_dict['track_occupancy'])), fill_value=np.nan) + tracks = tracks.reset_index(drop=False) + tracks.to_csv(output_path, index=False) + else: + tracks.to_csv(output_path, index=False) def main( labels: Labels, output_path: str, labels_path: str = None, - all_frames: bool = True, + all_frames: bool = False, video: Video = None, csv: bool = False, ): @@ -435,7 +442,7 @@ def main( ) if csv: - write_csv_file(output_path, data_dict) + write_csv_file(output_path, data_dict, all_frames=all_frames) else: write_occupancy_file(output_path, data_dict, transpose=True) diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py index 4640ee117..17666f1e1 100644 --- a/sleap/io/format/csv.py +++ b/sleap/io/format/csv.py @@ -45,6 +45,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, + all_frames: bool = False, video: Video = None, ): """Writes csv file for :py:class:`Labels` `source_object`. @@ -53,18 +54,19 @@ def write( filename: The filename for the output file. source_object: The :py:class:`Labels` from which to get data from. source_path: Path for the labels object - video: The :py:class:`Video` from which toget data from. If no `video` is + all_frames: A boolean flag to determine whether to include all frames or + only those with tracking data in the export. + video: The :py:class:`Video` from which to get data from. If no `video` is specified, then the first video in `source_object` videos list will be used. If there are no :py:class:`Labeled Frame`s in the `video`, then no analysis file will be written. """ from sleap.info.write_tracking_h5 import main as write_analysis - write_analysis( labels=source_object, output_path=filename, labels_path=source_path, - all_frames=True, + all_frames=all_frames, video=video, csv=True, ) diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 4c39ec8b6..73ab65b6e 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -101,6 +101,7 @@ def write( filename: str, source_object: object, source_path: Optional[str] = None, + all_frames: bool = False, video: Optional[Video] = None, ): """Writes the object to a file.""" @@ -460,4 +461,4 @@ def write_data(block, source: Labels, video: Video): print(f"\n\tWriting failed with following error:\n{e}!") finally: if nix_file is not None: - nix_file.close() + nix_file.close() \ No newline at end of file diff --git a/sleap/io/format/sleap_analysis.py b/sleap/io/format/sleap_analysis.py index cc6eedb6f..bef5fd1b9 100644 --- a/sleap/io/format/sleap_analysis.py +++ b/sleap/io/format/sleap_analysis.py @@ -129,6 +129,7 @@ def write( filename: str, source_object: Labels, source_path: str = None, + all_frames: bool = False, video: Video = None, ): """Writes analysis file for :py:class:`Labels` `source_object`. @@ -147,6 +148,6 @@ def write( labels=source_object, output_path=filename, labels_path=source_path, - all_frames=True, + all_frames=all_frames, video=video, ) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 899b1f4a0..7c6b2fd36 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -279,6 +279,66 @@ def assert_videos_written(num_videos: int, labels_path: str = None): ExportAnalysisFile.do_action(context=context, params=params) assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + # Test with all_videos True and all_frames True + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Test with all_videos False and all_frames True + params = {"all_videos": False, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=1, labels_path=context.state["filename"]) + + # Test with all_videos False and all_frames False + params = {"all_videos": False, "all_frames": False, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=1, labels_path=context.state["filename"]) + + # Add labels path and test with all_videos True and all_frames True (single video) + context.state["filename"] = str(tmpdir.with_name("path.to.labels")) + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Add a video (no labels) and test with all_videos True and all_frames True + labels.add_video(small_robot_mp4_vid) + + params = {"all_videos": True, "all_frames": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Test with videos with the same filename + (tmpdir / "session1").mkdir() + (tmpdir / "session2").mkdir() + shutil.copy( + centered_pair_predictions.video.backend.filename, + tmpdir / "session1" / "video.mp4", + ) + shutil.copy(small_robot_mp4_vid.backend.filename, tmpdir / "session2" / "video.mp4") + labels.videos[0].backend.filename = str(tmpdir / "session1" / "video.mp4") + labels.videos[1].backend.filename = str(tmpdir / "session2" / "video.mp4") + params = {"all_videos": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=2, labels_path=context.state["filename"]) + + # Remove all videos and test + all_videos = list(labels.videos) + for video in all_videos: + labels.remove_video(labels.videos[-1]) + + params = {"all_videos": True, "all_frames": True, "csv": csv} # Test with videos with the same filename (tmpdir / "session1").mkdir() (tmpdir / "session2").mkdir()