Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to analysis export for exporting predictions for all frames including those with no predictions #1624

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,15 +496,42 @@ def add_submenu_choices(menu, title, options, key):
add_menu_item(
export_csv_menu,
"export_csv_current",
"Current Video...",
"Current Video (only tracked frames)...",
self.commands.exportCSVFile,
)
add_menu_item(
export_csv_menu,
"export_csv_all",
"All Videos...",
"All Videos (only tracked frames)...",
lambda: self.commands.exportCSVFile(all_videos=True),
)

fileMenu.addSeparator()
add_menu_item(
export_csv_menu,
"export_csv_current_all_frames",
"Current Video (all frames)...",
lambda: self.commands.exportCSVFile(all_frames=True),
)
add_menu_item(
export_csv_menu,
"export_csv_all_frames",
"All Videos (all frames)...",
lambda: self.commands.exportCSVFile(all_frames=True, all_videos=True),

export_csv_menu.addSeparator()
add_menu_item(
export_csv_menu,
"export_csv_current_all_frames",
"Current Video (all frames)...",
self.commands.exportCSVFile(all_frames=True),
)
add_menu_item(
export_csv_menu,
"export_csv_all_all_frames",
"All Videos (all frames)...",
lambda: self.commands.exportCSVFile(all_frames=True),
)
vtsai881 marked this conversation as resolved.
Show resolved Hide resolved

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)

Expand Down
23 changes: 17 additions & 6 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,13 @@ def saveProjectAs(self):
"""Show gui to save project as a new file."""
self.execute(SaveProjectAs)

def exportAnalysisFile(self, all_videos: bool = False):
def exportAnalysisFile(self, all_videos: bool = False, all_frames: bool = False):
"""Shows gui for exporting analysis h5 file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False)
self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=False)

def exportCSVFile(self, all_videos: bool = False):
def exportCSVFile(self, all_videos: bool = False, all_frames: bool = False):
"""Shows gui for exporting analysis csv file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True)
self.execute(ExportAnalysisFile, all_videos=all_videos, all_frames=all_frames, csv=True)

def exportNWB(self):
"""Show gui for exporting nwb file."""
Expand Down Expand Up @@ -1142,13 +1142,24 @@ def do_action(cls, context: CommandContext, params: dict):
adaptor = NixAdaptor
else:
adaptor = SleapAnalysisAdaptor

if params['all_frames']:
adaptor.write(
filename=output_path,
all_frames=params.get('all_frames', False),
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

)
else:
adaptor.write(
filename=output_path,
all_frames=False,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
def ask_for_filename(default_name: str, csv: bool) -> str:
Expand Down
12 changes: 10 additions & 2 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,22 @@ def write_csv_file(output_path, data_dict):
tracks.append(detection)

tracks = pd.DataFrame(tracks)
tracks.to_csv(output_path, index=False)
all_frames = globals().get('all_frames', False)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of globals().get('all_frames', False) to retrieve the all_frames variable is not a standard practice and can lead to unexpected behavior. It would be safer and more maintainable to pass all_frames as a parameter to the write_csv_file function.


if all_frames:
tracks = tracks.set_index('frame_idx')
tracks = tracks.reindex(range(0, len(data_dict['track_occupancy'])), fill_value=np.nan)
tracks = tracks.reset_index(drop=False)
tracks.to_csv(output_path, index=False)
else:
tracks.to_csv(output_path, index=False)


def main(
labels: Labels,
output_path: str,
labels_path: str = None,
all_frames: bool = True,
all_frames: bool = False,
video: Video = None,
csv: bool = False,
):
Expand Down
3 changes: 3 additions & 0 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def write(
filename: str,
source_object: Labels,
source_path: str = None,
all_frames: bool: False,
vtsai881 marked this conversation as resolved.
Show resolved Hide resolved
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.
Expand All @@ -53,6 +54,8 @@ def write(
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
all_frames: A boolean flag to determine whether to include all frames
or only those with tracking data in the export.
video: The :py:class:`Video` from which toget data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
vtsai881 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading