Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to export to CSV via sleap-convert and API #1730

Merged
merged 15 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ optional arguments:
analysis file for the latter video is given a default name.
--format FORMAT Output format. Default ('slp') is SLEAP dataset;
'analysis' results in analysis.h5 file; 'analysis.nix' results
in an analysis nix file; 'h5' or 'json' results in SLEAP dataset
in an analysis nix file; 'analysis.csv' results
in an analysis csv file; 'h5' or 'json' results in SLEAP dataset
with specified file format.
--video VIDEO Path to video (if needed for conversion).
```
Expand Down
22 changes: 21 additions & 1 deletion sleap/io/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_parser():
help="Output format. Default ('slp') is SLEAP dataset; "
"'analysis' results in analysis.h5 file; "
"'analysis.nix' results in an analysis nix file;"
"'analysis.csv' results in an analysis csv file;"
"'h5' or 'json' results in SLEAP dataset "
"with specified file format.",
)
Expand Down Expand Up @@ -135,7 +136,12 @@ def main(args: list = None):
outnames = [path for path in args.outputs]
if len(outnames) < len(vids):
# if there are less outnames provided than videos to convert...
out_suffix = "nix" if "nix" in args.format else "h5"
if "nix" in args.format:
out_suffix = "nix"
elif "csv" in args.format:
out_suffix = "csv"
else:
out_suffix = "h5"
Comment on lines +139 to +144
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor the logic for determining the output file suffix into a separate function to improve code readability and maintainability.

+ def determine_output_suffix(format_arg: str) -> str:
+     if "nix" in format_arg:
+         return "nix"
+     elif "csv" in format_arg:
+         return "csv"
+     else:
+         return "h5"
- if "nix" in args.format:
-     out_suffix = "nix"
- elif "csv" in args.format:
-     out_suffix = "csv"
- else:
-     out_suffix = "h5"
+ out_suffix = determine_output_suffix(args.format)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if "nix" in args.format:
out_suffix = "nix"
elif "csv" in args.format:
out_suffix = "csv"
else:
out_suffix = "h5"
def determine_output_suffix(format_arg: str) -> str:
if "nix" in format_arg:
return "nix"
elif "csv" in format_arg:
return "csv"
else:
return "h5"
out_suffix = determine_output_suffix(args.format)

fn = args.input_path
fn = re.sub("(\.json(\.zip)?|\.h5|\.slp)$", "", fn)
fn = PurePath(fn)
Expand All @@ -158,6 +164,20 @@ def main(args: list = None):
NixAdaptor.write(outname, labels, args.input_path, video)
except ValueError as e:
print(e.args[0])

elif "csv" in args.format:
from sleap.info.write_tracking_h5 import main as write_analysis

for video, output_path in zip(vids, outnames):
write_analysis(
labels,
output_path=output_path,
labels_path=args.input_path,
all_frames=True,
video=video,
csv=True,
)

else:
from sleap.info.write_tracking_h5 import main as write_analysis

Expand Down
13 changes: 13 additions & 0 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,19 @@ def export(self, filename: str):

SleapAnalysisAdaptor.write(filename, self)

def export_csv(self, filename: str):
    """Write this labels object out as a CSV file.

    Args:
        filename: Path of the CSV file to write.

    Notes:
        The writing itself is delegated to `CSVAdaptor.write`, which is
        handed the output path and this labels object.
    """
    # The adaptor is imported at call time rather than at module import time.
    from sleap.io.format.csv import CSVAdaptor as csv_adaptor

    csv_adaptor.write(filename, self)
eberrigan marked this conversation as resolved.
Show resolved Hide resolved

def export_nwb(
self,
filename: str,
Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest


@pytest.mark.parametrize("format", ["analysis", "analysis.nix"])
@pytest.mark.parametrize("format", ["analysis", "analysis.nix", "analysis.csv"])
def test_analysis_format(
min_labels_slp: Labels,
min_labels_slp_path: Labels,
Expand All @@ -27,7 +27,7 @@ def generate_filenames(paths, format="analysis"):
labels_path = str(slp_path)
fn = re.sub("(\\.json(\\.zip)?|\\.h5|\\.slp)$", "", labels_path)
fn = PurePath(fn)
out_suffix = "nix" if "nix" in format else "h5"
out_suffix = "nix" if "nix" in format else "csv" if "csv" in format else "h5"
default_names = [
default_analysis_filename(
labels=labels,
Expand Down
44 changes: 44 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pandas as pd
import pytest
import numpy as np
from pathlib import Path, PurePath

import sleap
from sleap.info.write_tracking_h5 import get_nodes_as_np_strings
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
Expand Down Expand Up @@ -1559,3 +1561,45 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir):
# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(centered_pair_predictions, read_labels)


@pytest.mark.parametrize(
    "labels_fixture_name",
    [
        "centered_pair_labels",
        "centered_pair_predictions",
        "min_labels",
        "min_labels_slp",
        "min_labels_robot",
    ],
)
def test_export_csv(labels_fixture_name, tmpdir, request):
    """Smoke test: `export_csv` creates a file for each listed fixture.

    The fixture is looked up by name through `request` so that a single
    parametrized test can cover several labels fixtures.
    """
    labels = request.getfixturevalue(labels_fixture_name)

    # Output goes into the per-test temporary directory.
    csv_filename = Path(tmpdir) / f"{labels_fixture_name}_export.csv"
    labels.export_csv(str(csv_filename))

    # Only existence is checked here; file contents are not inspected.
    assert csv_filename.is_file(), f"CSV file '{csv_filename}' was not created"


def test_exported_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path):
    """Check that an exported CSV matches the reference CSV fixture.

    Exports `min_labels_slp` with `export_csv`, reads the written file back
    with pandas, and compares it against the CSV located at
    `minimal_instance_predictions_csv_path`.
    """
    labels = min_labels_slp
    filename_csv = Path(tmpdir) / "minimal_instance_predictions_export.csv"

    # `export_csv` declares `filename: str`, so hand it a string, not a Path.
    labels.export_csv(str(filename_csv))

    # Load both the freshly exported file and the reference file.
    labels_csv = pd.read_csv(filename_csv)
    csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path)

    assert labels_csv.equals(csv_predictions)

    # Column count: three columns per node plus three non-node columns.
    # NOTE(review): presumably x/y/score per node plus track/frame/instance
    # score columns -- confirm against the CSV writer.
    assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3
Loading