Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix track only PredictedInstance #2028

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def make_predict_cli_call(
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
"tracking.only_predicted_instances",
)

for key in bool_items_as_ints:
Expand Down
19 changes: 19 additions & 0 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,25 @@ def from_numpy(
)


def convert_to_predicted_instance(
inst: Union[Instance, PredictedInstance],
*,
score: float = 1.0,
tracking_score: float = 0.0,
) -> PredictedInstance:
"""Convert an Instance to a PredictedInstance, if it's not one already.

Score is by default 1.0, like a user-defined instance.
"""
if isinstance(inst, PredictedInstance):
return inst

kwargs = attr.asdict(inst)
kwargs["score"] = score
kwargs["tracking_score"] = tracking_score
return PredictedInstance(**kwargs)


def make_instance_cattr() -> cattr.Converter:
"""Create a cattr converter for Lists of Instances/PredictedInstances.

Expand Down
72 changes: 41 additions & 31 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tracking tools for linking grouped instances over time."""

from collections import deque, defaultdict
from collections import deque
import abc
import attr
import numpy as np
Expand All @@ -9,6 +9,7 @@
from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple

from sleap import Track, LabeledFrame, Skeleton
from sleap.instance import convert_to_predicted_instance

from sleap.nn.tracker.components import (
factory_object_keypoint_similarity,
Expand Down Expand Up @@ -580,6 +581,7 @@ class Tracker(BaseTracker):
robust_best_instance: float = 1.0

min_new_track_points: int = 0
only_predicted_instances: bool = True

track_matching_queue: Deque[MatchedFrameInstances] = attr.ib()

Expand Down Expand Up @@ -639,6 +641,28 @@ def unique_tracks_in_queue(self) -> List[Track]:
def uses_image(self):
return getattr(self.candidate_maker, "uses_image", False)

def infer_next_timestep(self, t: Optional[int] = None) -> int:
"""Infer timestep if not provided."""
# Timestep was provided
if t is not None:
return t

if self.has_max_tracking and len(self.track_matching_queue_dict) > 0:
# Default to last timestep + 1 if available.
# Here we find the track that has the most instances.
track_with_max_instances = max(
self.track_matching_queue_dict,
key=lambda track: len(self.track_matching_queue_dict[track]),
)
return 1 + self.track_matching_queue_dict[track_with_max_instances][-1].t

# Default to last timestep + 1 if available.
if not self.has_max_tracking and len(self.track_matching_queue) > 0:
return self.track_matching_queue[-1].t + 1

# Default to 0
return 0

def track(
self,
untracked_instances: List[InstanceType],
Expand Down Expand Up @@ -667,31 +691,7 @@ def track(
return untracked_instances

# Infer timestep if not provided.
if t is None:
if self.has_max_tracking:
if len(self.track_matching_queue_dict) > 0:

# Default to last timestep + 1 if available.
# Here we find the track that has the most instances.
track_with_max_instances = max(
self.track_matching_queue_dict,
key=lambda track: len(self.track_matching_queue_dict[track]),
)
t = (
self.track_matching_queue_dict[track_with_max_instances][-1].t
+ 1
)

else:
t = 0
else:
if len(self.track_matching_queue) > 0:

# Default to last timestep + 1 if available.
t = self.track_matching_queue[-1].t + 1

else:
t = 0
t = self.infer_next_timestep(t)

# Initialize containers for tracked instances at the current timestep.
tracked_instances = []
Expand Down Expand Up @@ -844,6 +844,7 @@ def make_tracker_by_name(
robust: float = 1.0,
min_new_track_points: int = 0,
min_match_points: int = 0,
only_predicted_instances: bool = True,
# Optical flow options
img_scale: float = 1.0,
of_window_size: int = 21,
Expand Down Expand Up @@ -942,6 +943,7 @@ def pre_cull_function(inst_list):
max_tracks=max_tracks,
target_instance_count=target_instance_count,
post_connect_single_breaks=post_connect_single_breaks,
only_predicted_instances=only_predicted_instances,
)

if target_instance_count and kf_init_frame_count:
Expand Down Expand Up @@ -1058,6 +1060,11 @@ def get_by_name_factory_options(cls):
option["help"] = "Minimum points for match candidates"
options.append(option)

option = dict(name="only_predicted_instances", default=1)
option["type"] = int
option["help"] = "Track only predicted instances, not user-defined instances."
options.append(option)

option = dict(name="img_scale", default=1.0)
option["type"] = float
option["help"] = "For optical-flow: Image scale"
Expand Down Expand Up @@ -1166,9 +1173,7 @@ class FlowTracker(Tracker):
candidate_maker: object = attr.ib(factory=FlowCandidateMaker)


attr.s(auto_attribs=True)


@attr.s(auto_attribs=True)
class FlowMaxTracker(Tracker):
"""Pre-configured tracker to use optical flow shifted candidates with max tracks."""

Expand Down Expand Up @@ -1520,12 +1525,17 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele

# Run tracking on every frame
for lf in frames:
# Use only the predicted instances
if tracker.only_predicted_instances:
instances = lf.predicted_instances
else:
instances = [convert_to_predicted_instance(inst) for inst in lf.instances]

# Clear the tracks
for inst in lf.instances:
for inst in instances:
inst.track = None

track_args = dict(untracked_instances=lf.instances)
track_args = dict(untracked_instances=instances)
if tracker.uses_image:
track_args["img"] = lf.video[lf.frame_idx]
else:
Expand Down
Loading