Skip to content

Commit

Permalink
remove max_tracking argument, just use max_tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 2, 2024
1 parent a22a50e commit c656bb8
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 50 deletions.
10 changes: 4 additions & 6 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
[--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
[--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
[--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER]
[--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
[--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
Expand Down Expand Up @@ -184,10 +184,8 @@ optional arguments:
Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
--tracking.max_tracking TRACKING.MAX_TRACKING
If true then the tracker will cap the max number of tracks. (default: False)
--tracking.max_tracks TRACKING.MAX_TRACKS
Maximum number of tracks to be tracked by the tracker. (default: None)
Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
Expand Down Expand Up @@ -261,13 +259,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction
**5. Inference with max tracks limit:**

```none
sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
```

**6. Re-tracking without pose inference:**

```none
sleap-track --tracking.tracker simple --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
```

**7. Select GPU for pose inference:**
Expand Down
4 changes: 0 additions & 4 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -521,10 +521,6 @@ inference:
text: '<b>Tracking</b>:<br />
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
# default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
Expand Down
6 changes: 0 additions & 6 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ def make_predict_cli_call(
if self.inference_params["tracking.tracker"] in compat_trackers:
tname = self.inference_params["tracking.tracker"][: -len("maxtracks")]
self.inference_params["tracking.tracker"] = tname
self.inference_params["tracking.max_tracking"] = True

# Setting max_tracks to a value means we want to use the max_tracking mode.
if self.inference_params.get("tracking.max_tracks") is not None:
self.inference_params["tracking.max_tracking"] = True

# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
Expand All @@ -259,7 +254,6 @@ def make_predict_cli_call(
bool_items_as_ints = (
"tracking.pre_cull_to_target",
"tracking.pre_cull_merge_instances",
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
Expand Down
4 changes: 0 additions & 4 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4914,14 +4914,10 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
use_max_tracker = (
tracker_max_instances is not None and tracker_max_instances > 0
)
predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
max_tracking=use_max_tracker,
max_tracks=tracker_max_instances,
# clean_instance_count=tracker_max_instances,
)
Expand Down
16 changes: 6 additions & 10 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,6 @@ def make_tracker_by_name(
kf_node_indices: Optional[list] = None,
# Max tracking options
max_tracks: Optional[int] = None,
max_tracking: bool = False,
prefer_reassigning_track: bool = False,
allow_reassigning_track: bool = False,
# Object keypoint similarity options
Expand All @@ -956,9 +955,8 @@ def make_tracker_by_name(
report_rate: float = 2.0,
**kwargs,
) -> BaseTracker:
# Parse max_tracking arguments, only True if max_tracks is not None and > 0
max_tracking = max_tracking and max_tracks is not None and max_tracks > 0
max_tracks = max_tracks if max_tracking else -1
# Parse max_tracks, set to -1 if None
max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1

if tracker.lower() == "none":
candidate_maker = None
Expand Down Expand Up @@ -1056,14 +1054,12 @@ def get_by_name_factory_options(cls):
]
options.append(option)

option = dict(name="max_tracking", default=False)
option["type"] = bool
option["help"] = "If true then the tracker will cap the max number of tracks."
options.append(option)

option = dict(name="max_tracks", default=None)
option["type"] = int
option["help"] = "Maximum number of tracks to be tracked by the tracker."
option["help"] = (
"Maximum number of tracks to be tracked by the tracker. "
"No maximum if set to -1."
)
options.append(option)

option = dict(name="target_instance_count", default=0)
Expand Down
3 changes: 0 additions & 3 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,6 @@ def test_retracking(
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks":
cmd += " --tracking.max_tracking 1"
cmd += " --tracking.max_tracks 2"
if output_path == "not_default":
output_path = Path(tmpdir, "tracked_slp.slp")
Expand Down Expand Up @@ -1790,15 +1789,13 @@ def test_max_tracks_matching_queue(
):
"""Test flow max tracks instance generation."""
labels: Labels = centered_pair_predictions
max_tracking = True
track_window = 5

# Setup flow max tracker
tracker: Tracker = Tracker.make_tracker_by_name(
tracker=trackername,
track_window=track_window,
save_shifted_instances=True,
max_tracking=max_tracking,
max_tracks=max_tracks,
)

Expand Down
18 changes: 6 additions & 12 deletions tests/nn/test_tracker_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def make_inst(x, y):
return insts


def test_max_tracking_large_gap_single_track():
def test_max_tracks_large_gap_single_track():
# Track 2 instances with gap > window size
preds = make_insts(
[
Expand Down Expand Up @@ -282,8 +282,7 @@ def test_max_tracking_large_gap_single_track():
tracker="simple",
match="hungarian",
track_window=2,
# max_tracks=2,
max_tracking=False,
max_tracks=-1,
)

tracked = []
Expand All @@ -299,7 +298,6 @@ def test_max_tracking_large_gap_single_track():
match="hungarian",
track_window=2,
max_tracks=2,
max_tracking=True,
)

tracked = []
Expand All @@ -311,7 +309,7 @@ def test_max_tracking_large_gap_single_track():
assert len(all_tracks) == 2


def test_max_tracking_small_gap_on_both_tracks():
def test_max_tracks_small_gap_on_both_tracks():
# Test 2 instances with both tracks with gap > window size
preds = make_insts(
[
Expand Down Expand Up @@ -344,8 +342,7 @@ def test_max_tracking_small_gap_on_both_tracks():
tracker="simple",
match="hungarian",
track_window=2,
# max_tracks=2,
max_tracking=False,
max_tracks=-1,
)

tracked = []
Expand All @@ -361,7 +358,6 @@ def test_max_tracking_small_gap_on_both_tracks():
match="hungarian",
track_window=2,
max_tracks=2,
max_tracking=True,
)

tracked = []
Expand All @@ -373,7 +369,7 @@ def test_max_tracking_small_gap_on_both_tracks():
assert len(all_tracks) == 2


def test_max_tracking_extra_detections():
def test_max_tracks_extra_detections():
# Test having more than 2 detected instances in a frame
preds = make_insts(
[
Expand Down Expand Up @@ -411,8 +407,7 @@ def test_max_tracking_extra_detections():
tracker="simple",
match="hungarian",
track_window=2,
# max_tracks=2,
max_tracking=False,
max_tracks=-1,
)

tracked = []
Expand All @@ -428,7 +423,6 @@ def test_max_tracking_extra_detections():
match="hungarian",
track_window=2,
max_tracks=2,
max_tracking=True,
)

tracked = []
Expand Down
7 changes: 2 additions & 5 deletions tests/nn/test_tracking_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path):
cli = (
"--tracking.tracker simple "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
"--tracking.max_tracks 2 "
"--frames 200-300 "
f"-o {tmpdir}/simplemaxtracks.slp "
f"{centered_pair_predictions_slp_path}"
Expand Down Expand Up @@ -107,13 +107,12 @@ def main(f, dir):
)

def make_tracker(
tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0
tracker_name, matcher_name, sim_name, max_tracks, scale=0
):
tracker = trackers[tracker_name](
matching_function=matchers[matcher_name],
similarity_function=similarities[sim_name],
max_tracks=max_tracks,
max_tracking=max_tracking,
)
if scale:
tracker.candidate_maker.img_scale = scale
Expand Down Expand Up @@ -142,7 +141,6 @@ def make_tracker_and_filename(*args, **kwargs):
tracker_name=tracker_name,
matcher_name=matcher_name,
max_tracks=2,
max_tracking=True,
sim_name=sim_name,
scale=scale,
)
Expand All @@ -152,7 +150,6 @@ def make_tracker_and_filename(*args, **kwargs):
tracker_name=tracker_name,
matcher_name=matcher_name,
max_tracks=2,
max_tracking=True,
sim_name=sim_name,
scale=0,
)
Expand Down

0 comments on commit c656bb8

Please sign in to comment.