Skip to content

Commit

Permalink
fixes from coderabbitai review
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 2, 2024
1 parent ab24b50 commit a22a50e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
13 changes: 6 additions & 7 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""

import math
import operator
from functools import reduce
from itertools import chain, combinations

Expand Down Expand Up @@ -1214,16 +1213,16 @@ def create_merged_instances(

# Ensure non-intersecting visible nodes
merged_instances = []
instance_subsets = chain(
*(combinations(instances, n) for n in range(2, len(instances) + 1))
instance_subsets = (
combinations(instances, n) for n in range(2, len(instances) + 1)
)
instance_subsets = chain.from_iterable(instance_subsets)
for subset in instance_subsets:
if not all_disjoint([s.nodes for s in subset]):
nodes = [s.nodes for s in subset]
if not all_disjoint(nodes):
continue

nodes_points = []
for instance in subset:
nodes_points.extend(list(instance.nodes_points))
nodes_points = [point for instance in subset for point in instance.nodes_points]
predicted_points = {node: point for node, point in nodes_points}

instance_score = reduce(lambda x, y: x * y, [s.score for s in subset])
Expand Down
1 change: 1 addition & 0 deletions sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def cull_frame_instances(

# Merge instances
if merge_instances:
logger.info("Merging instances with penalty: %f", merging_penalty)
merged_instances = create_merged_instances(
instances_list, penalty=merging_penalty
)
Expand Down
4 changes: 2 additions & 2 deletions sleap/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A miscellaneous set of utility functions.
"""A miscellaneous set of utility functions.
Try not to put things in here unless they really have no other place.
"""
Expand Down Expand Up @@ -35,7 +35,7 @@
class RateColumn(rich.progress.ProgressColumn):
"""Renders the progress rate."""

def render(self, task: "Task") -> rich.progress.Text:
def render(self, task: rich.progress.Task) -> rich.progress.Text:
"""Show progress rate."""
speed = task.speed
if speed is None:
Expand Down
2 changes: 1 addition & 1 deletion tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,7 +1817,7 @@ def test_max_tracks_matching_queue(

if trackername == "flow":
# Check that saved instances are pruned to track window
for key in tracker.candidate_maker.shifted_instances.keys():
for key in tracker.candidate_maker.shifted_instances:
assert lf.frame_idx - key[0] <= track_window # Keys are pruned
assert abs(key[0] - key[1]) <= track_window

Expand Down

0 comments on commit a22a50e

Please sign in to comment.