Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow tracking on Instances (by adding Instance.tracking_score) #1302

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def get_occupancy_and_points_matrices(
occupancy_matrix[track_i, frame_i] = 1

locations_matrix[frame_i, ..., track_i] = inst.numpy()
tracking_scores[frame_i, ..., track_i] = inst.tracking_score
if type(inst) == PredictedInstance:
point_scores[frame_i, ..., track_i] = inst.scores
instance_scores[frame_i, ..., track_i] = inst.score
tracking_scores[frame_i, ..., track_i] = inst.tracking_score

return (
occupancy_matrix,
Expand Down
10 changes: 6 additions & 4 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class Instance:
frame: A back reference to the :class:`LabeledFrame` that this
:class:`Instance` belongs to. This field is set when
instances are added to :class:`LabeledFrame` objects.
tracking_score: The instance-level track matching score.
"""

skeleton: Skeleton = attr.ib()
Expand All @@ -369,6 +370,8 @@ class Instance:
# The underlying Point array type that this instances point array should be.
_point_array_type = PointArray

tracking_score: float = attr.ib(default=0.0, converter=float)

@from_predicted.validator
def _validate_from_predicted_(
self, attribute, from_predicted: Optional["PredictedInstance"]
Expand Down Expand Up @@ -662,7 +665,8 @@ def __repr__(self) -> str:
f"video={self.video}, "
f"frame_idx={self.frame_idx}, "
f"points=[{pts}], "
f"track={self.track}"
f"track={self.track}, "
f"tracking_score={self.tracking_score:.2f}"
")"
)

Expand Down Expand Up @@ -998,11 +1002,9 @@ class PredictedInstance(Instance):

Args:
score: The instance-level grouping prediction score.
tracking_score: The instance-level track matching score.
"""

score: float = attr.ib(default=0.0, converter=float)
tracking_score: float = attr.ib(default=0.0, converter=float)

# The underlying Point array type that this instances point array should be.
_point_array_type = PredictedPointArray
Expand Down Expand Up @@ -1309,7 +1311,7 @@ def __len__(self) -> int:
"""Return number of instances associated with frame."""
return len(self.instances)

def __getitem__(self, index) -> Instance:
def __getitem__(self, index) -> Union[Instance, PredictedInstance]:
"""Return instance (retrieved by index)."""
return self.instances.__getitem__(index)

Expand Down
12 changes: 8 additions & 4 deletions sleap/io/format/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@


class LabelsV1Adaptor(format.adaptor.Adaptor):
FORMAT_ID = 1.2
FORMAT_ID = 1.3

# 1.0 points with gridline coordinates, top left corner at (0, 0)
# 1.1 points with midpixel coordinates, top left corner at (-0.5, -0.5)
# 1.2 adds track score to read and write functions
# 1.2 adds tracking score for PredictedInstance to read and write functions
# 1.3 adds tracking score for Instance to read and write functions

@property
def handles(self):
Expand Down Expand Up @@ -180,6 +181,9 @@ def read(
skeleton=skeleton,
track=track,
points=points[i["point_id_start"] : i["point_id_end"]],
tracking_score=i["tracking_score"]
if (format_id is not None and format_id >= 1.3)
else 0.0,
)
else: # PredictedInstance
instance = PredictedInstance(
Expand Down Expand Up @@ -438,11 +442,9 @@ def append_unique(old, new):
if instance_type is PredictedInstance:
score = instance.score
pid = pred_point_id + pred_point_id_offset
tracking_score = instance.tracking_score
else:
score = np.nan
pid = point_id + point_id_offset
tracking_score = np.nan

# Keep track of any from_predicted instance links, we will
# insert the correct instance_id in the dataset after we are
Expand All @@ -451,6 +453,8 @@ def append_unique(old, new):
instances_with_from_predicted.append(instance_id)
instances_from_predicted.append(instance.from_predicted)

tracking_score = instance.tracking_score

# Copy all the data
instances[instance_id] = (
instance_id + instance_id_offset,
Expand Down
3 changes: 1 addition & 2 deletions sleap/io/format/nix.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,12 @@ def chunked_write(
positions[index, :, node_map[m]] = np.array([np.nan, np.nan])

centroids[index, :] = inst.centroid
trackscore[index] = inst.tracking_score
if hasattr(inst, "score"):
instscore[index] = inst.score
trackscore[index] = inst.tracking_score
pointscore[index, :] = inst.scores
else:
instscore[index] = 0.0
trackscore[index] = 0.0
pointscore[index, :] = dflt_pointscore

frameid_array[start:end] = indices[: end - start]
Expand Down
15 changes: 13 additions & 2 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,16 +1189,27 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele
for inst in lf.instances:
inst.track = None

track_args = dict(untracked_instances=lf.instances)
# Prefer user instances over predicted instances
instances = []
if lf.has_user_instances:
instances_to_track = lf.user_instances
if lf.has_predicted_instances:
instances = lf.predicted_instances
else:
instances_to_track = lf.predicted_instances

track_args = {"untracked_instances": instances_to_track}

if tracker.uses_image:
track_args["img"] = lf.video[lf.frame_idx]
else:
track_args["img"] = None

instances.extend(tracker.track(**track_args))
new_lf = LabeledFrame(
frame_idx=lf.frame_idx,
video=lf.video,
instances=tracker.track(**track_args),
instances=instances,
)
new_lfs.append(new_lf)

Expand Down
31 changes: 30 additions & 1 deletion tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
from sleap.io.dataset import Labels, load_file
from sleap.io.legacy import load_labels_json_old
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.io.format import filehandle
from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame
Expand Down Expand Up @@ -746,6 +745,36 @@ def test_dont_unify_skeletons():
labels.to_dict()


def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str):
labels = centered_pair_predictions
lf = labels.labeled_frames[0]
pred_inst: PredictedInstance = lf[0]
skeleton = pred_inst.skeleton
track = pred_inst.track

# Initialize Instance
instance = Instance.from_pointsarray(
points=pred_inst.numpy(), skeleton=skeleton, track=track
)
instance.from_predicted = pred_inst
assert instance.tracking_score == 0.0
labels.add_instance(lf, instance)

instance.tracking_score = 0.5
pred_inst.tracking_score = 0.7

filename = str(PurePath(tmpdir, "labels.slp"))
labels.save(filename)

labels_loaded = sleap.load_file(filename)
lf_loaded = labels_loaded.labeled_frames[0]
pred_inst_loaded = lf_loaded.predicted_instances[0]
instance_loaded = lf_loaded.user_instances[0]

assert round(pred_inst_loaded.tracking_score, 1) == pred_inst.tracking_score
assert round(instance_loaded.tracking_score, 1) == instance.tracking_score


def test_instance_access():
labels = Labels()

Expand Down