diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..98e4e3d35 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -2,13 +2,35 @@ import operator import os import time - +import pytest import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components from sleap.io.dataset import Labels, LabeledFrame +@pytest.mark.parametrize( + "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"] +) +def test_kalman_tracker(tmpdir, centered_pair_predictions_slp_path, tracker_name): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + "--frames 200-300 " + "--tracking.similarity instance " + "--tracking.match hungarian " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/{tracker_name}.slp") + assert len(labels.tracks) == 2 + + def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): cli = ( "--tracking.tracker simple "