diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index d0bb1f3ba..5da9bad78 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -261,6 +261,7 @@ def make_predict_cli_call( "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", "tracking.oks_score_weighting", + "tracking.only_predicted_instances", ) for key in bool_items_as_ints: diff --git a/sleap/instance.py b/sleap/instance.py index 08a5c6ae6..9ffdd0cc7 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1177,6 +1177,25 @@ def from_numpy( ) +def convert_to_predicted_instance( + inst: Union[Instance, PredictedInstance], + *, + score: float = 1.0, + tracking_score: float = 0.0, +) -> PredictedInstance: + """Convert an Instance to a PredictedInstance, if it's not one already. + + Score is by default 1.0, like a user-defined instance. + """ + if isinstance(inst, PredictedInstance): + return inst + + kwargs = attr.asdict(inst) + kwargs["score"] = score + kwargs["tracking_score"] = tracking_score + return PredictedInstance(**kwargs) + + def make_instance_cattr() -> cattr.Converter: """Create a cattr converter for Lists of Instances/PredictedInstances. diff --git a/sleap/io/format/coco.py b/sleap/io/format/coco.py index 89b3a450c..44e7fb84a 100644 --- a/sleap/io/format/coco.py +++ b/sleap/io/format/coco.py @@ -180,7 +180,7 @@ def read( if flag == 0: # node not labeled for this instance - if (x, y) != (0, 0): + if (x, y) != (0, 0): # If labeled but invisible, place the node at the coord points[node] = Point(x, y, False) continue diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..fa11cbcae 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,6 +1,6 @@ """Tracking tools for linking grouped instances over time.""" -from collections import deque, defaultdict +from collections import deque import abc import attr import numpy as np @@ -9,6 +9,7 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple from sleap import Track, LabeledFrame, Skeleton +from sleap.instance import convert_to_predicted_instance from sleap.nn.tracker.components import ( factory_object_keypoint_similarity, @@ -580,6 +581,7 @@ class Tracker(BaseTracker): robust_best_instance: float = 1.0 min_new_track_points: int = 0 + only_predicted_instances: bool = True track_matching_queue: Deque[MatchedFrameInstances] = attr.ib() @@ -639,6 +641,28 @@ def unique_tracks_in_queue(self) -> List[Track]: def uses_image(self): return getattr(self.candidate_maker, "uses_image", False) + def infer_next_timestep(self, t: Optional[int] = None) -> int: + """Infer timestep if not provided.""" + # Timestep was provided + if t is not None: + return t + + if self.has_max_tracking and len(self.track_matching_queue_dict) > 0: + # Default to last timestep + 1 if available. + # Here we find the track that has the most instances. + track_with_max_instances = max( + self.track_matching_queue_dict, + key=lambda track: len(self.track_matching_queue_dict[track]), + ) + return 1 + self.track_matching_queue_dict[track_with_max_instances][-1].t + + # Default to last timestep + 1 if available. + if not self.has_max_tracking and len(self.track_matching_queue) > 0: + return self.track_matching_queue[-1].t + 1 + + # Default to 0 + return 0 + def track( self, untracked_instances: List[InstanceType], @@ -667,31 +691,7 @@ def track( return untracked_instances # Infer timestep if not provided. - if t is None: - if self.has_max_tracking: - if len(self.track_matching_queue_dict) > 0: - - # Default to last timestep + 1 if available. - # Here we find the track that has the most instances. - track_with_max_instances = max( - self.track_matching_queue_dict, - key=lambda track: len(self.track_matching_queue_dict[track]), - ) - t = ( - self.track_matching_queue_dict[track_with_max_instances][-1].t - + 1 - ) - - else: - t = 0 - else: - if len(self.track_matching_queue) > 0: - - # Default to last timestep + 1 if available. - t = self.track_matching_queue[-1].t + 1 - - else: - t = 0 + t = self.infer_next_timestep(t) # Initialize containers for tracked instances at the current timestep. tracked_instances = [] @@ -844,6 +844,7 @@ def make_tracker_by_name( robust: float = 1.0, min_new_track_points: int = 0, min_match_points: int = 0, + only_predicted_instances: bool = True, # Optical flow options img_scale: float = 1.0, of_window_size: int = 21, @@ -942,6 +943,7 @@ def pre_cull_function(inst_list): max_tracks=max_tracks, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, + only_predicted_instances=only_predicted_instances, ) if target_instance_count and kf_init_frame_count: @@ -1058,6 +1060,11 @@ def get_by_name_factory_options(cls): option["help"] = "Minimum points for match candidates" options.append(option) + option = dict(name="only_predicted_instances", default=1) + option["type"] = int + option["help"] = "Track only predicted instances, not user-defined instances." + options.append(option) + option = dict(name="img_scale", default=1.0) option["type"] = float option["help"] = "For optical-flow: Image scale" @@ -1166,9 +1173,7 @@ class FlowTracker(Tracker): candidate_maker: object = attr.ib(factory=FlowCandidateMaker) -attr.s(auto_attribs=True) - - +@attr.s(auto_attribs=True) class FlowMaxTracker(Tracker): """Pre-configured tracker to use optical flow shifted candidates with max tracks.""" @@ -1520,12 +1525,17 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele # Run tracking on every frame for lf in frames: + # Use only the predicted instances + if tracker.only_predicted_instances: + instances = lf.predicted_instances + else: + instances = [convert_to_predicted_instance(inst) for inst in lf.instances] # Clear the tracks - for inst in lf.instances: + for inst in instances: inst.track = None - track_args = dict(untracked_instances=lf.instances) + track_args = dict(untracked_instances=instances) if tracker.uses_image: track_args["img"] = lf.video[lf.frame_idx] else: