Skip to content

Commit

Permalink
Merge pull request roboflow#1603 from roboflow/feat/refactor-byte-track
Browse files Browse the repository at this point in the history
ByteTrack: Remove BaseTrack, refactor, add types, remove dead code, move out shared Kalman filter
  • Loading branch information
LinasKo authored Nov 1, 2024
2 parents 589d4f9 + 22d827f commit 1f610d8
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 253 deletions.
54 changes: 0 additions & 54 deletions supervision/tracker/byte_tracker/basetrack.py

This file was deleted.

217 changes: 22 additions & 195 deletions supervision/tracker/byte_tracker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,186 +5,9 @@
from supervision.detection.core import Detections
from supervision.detection.utils import box_iou_batch
from supervision.tracker.byte_tracker import matching
from supervision.tracker.byte_tracker.basetrack import BaseTrack, TrackState
from supervision.tracker.byte_tracker.kalman_filter import KalmanFilter


class IdCounter:
    """Hand out sequential integer ids, starting from a configurable value.

    Attributes:
        NO_ID: Sentinel meaning "no id has been assigned". It is a class-level
            constant so it can be read from the class (`IdCounter.NO_ID`) as
            well as from any instance.
        start_id: The first id returned after construction or `reset()`.
    """

    NO_ID: int = -1

    def __init__(self, start_id: int = 0):
        """
        Args:
            start_id: First id to hand out. Must be greater than `NO_ID` so
                that a real id can never collide with the sentinel.

        Raises:
            ValueError: If `start_id` is not greater than `NO_ID`.
        """
        # Validate before touching instance state so a failed construction
        # leaves nothing half-initialised.
        if start_id <= self.NO_ID:
            raise ValueError("start_id must be greater than -1")
        self.start_id = start_id
        self.reset()

    def reset(self) -> None:
        """Rewind the counter so the next id handed out is `start_id`."""
        self._id = self.start_id

    def new_id(self) -> int:
        """Return the next unused id and advance the counter by one."""
        returned_id = self._id
        self._id += 1
        return returned_id


class STrack(BaseTrack):
    """A single tracked object whose box is estimated with a Kalman filter.

    The first four entries of `mean` are read as
    `(center x, center y, aspect ratio, height)`; see `tlwh` and
    `tlwh_to_xyah` for the conversions to and from `(x, y, w, h)` boxes.

    NOTE(review): `state`, `frame_id`, `start_frame` and `end_frame` are
    presumably declared by `BaseTrack` — confirm against that class.
    """

    # Filter instance shared by every track; used by `multi_predict`.
    shared_kalman = KalmanFilter()

    def __init__(
        self,
        tlwh,
        score,
        minimum_consecutive_frames,
        internal_id_counter: IdCounter,
        external_id_counter: IdCounter,
    ):
        """
        Args:
            tlwh: Initial box as `(top left x, top left y, width, height)`.
            score: Detection confidence associated with the box.
            minimum_consecutive_frames: Number of consecutive updates after
                which the track is marked activated and receives an
                external id.
            internal_id_counter: Source of internal track ids.
            external_id_counter: Source of external track ids.
        """
        super().__init__()
        # The track stays un-activated until `activate()` is called.
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0

        self.minimum_consecutive_frames = minimum_consecutive_frames

        self.internal_id_counter = internal_id_counter
        self.external_id_counter = external_id_counter
        # Both ids start at the "unassigned" sentinel.
        self.internal_track_id = self.internal_id_counter.NO_ID
        self.external_track_id = self.external_id_counter.NO_ID

    def predict(self):
        """Advance this track's state by one step with its own filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # NOTE(review): index 7 is presumably the height-velocity
            # component of the state vector — confirm against KalmanFilter.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(
            mean_state, self.covariance
        )

    @staticmethod
    def multi_predict(stracks):
        """Advance every track in `stracks` by one step in a single batch.

        Uses the class-level `shared_kalman` filter and writes the predicted
        mean and covariance back onto each track. Does nothing when
        `stracks` is empty.
        """
        if len(stracks) == 0:
            return

        multi_mean = []
        multi_covariance = []
        for st in stracks:
            mean = st.mean.copy()
            if st.state != TrackState.Tracked:
                # Same zeroing as in `predict` for tracks not currently tracked.
                mean[7] = 0
            multi_mean.append(mean)
            multi_covariance.append(st.covariance)

        multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(
            np.asarray(multi_mean), np.asarray(multi_covariance)
        )
        for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
            st.mean = mean
            st.covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet.

        Stores the filter, draws a fresh internal id and initialises the
        state from the box given at construction.

        Args:
            kalman_filter: Filter used for this track's later predictions
                and updates.
            frame_id: Index of the frame on which the track starts.
        """
        self.kalman_filter = kalman_filter
        self.internal_track_id = self.internal_id_counter.new_id()
        self.mean, self.covariance = self.kalman_filter.initiate(
            self.tlwh_to_xyah(self._tlwh)
        )

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Tracks created on frame 1 are marked activated immediately.
        if frame_id == 1:
            self.is_activated = True

        # With a threshold of one frame the external id is assigned right away.
        if self.minimum_consecutive_frames == 1:
            self.external_track_id = self.external_id_counter.new_id()

        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id):
        """Resume this track from a newly matched detection.

        Args:
            new_track: Track wrapping the matched detection; its `tlwh` and
                `score` are read.
            frame_id: Index of the current frame.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked

        self.frame_id = frame_id
        self.score = new_track.score

    def update(self, new_track, frame_id):
        """Update a matched track with a new detection.

        Once `tracklet_len` reaches `minimum_consecutive_frames` the track is
        marked activated and, if it has none yet, given an external id.

        Args:
            new_track: Track wrapping the matched detection; its `tlwh` and
                `score` are read.
            frame_id: Index of the current frame.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)
        )
        self.state = TrackState.Tracked
        if self.tracklet_len == self.minimum_consecutive_frames:
            self.is_activated = True
            if self.external_track_id == self.external_id_counter.NO_ID:
                self.external_track_id = self.external_id_counter.new_id()

        self.score = new_track.score

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        Before the track has a state (`mean is None`) a copy of the box given
        at construction is returned.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        # aspect ratio * height -> width
        ret[2] *= ret[3]
        # center -> top-left corner
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.

        Floating-point input keeps its dtype. Any other input (for example
        integer pixel coordinates) is promoted to float64, because the
        in-place arithmetic below yields fractional values that an integer
        array cannot hold.
        """
        ret = np.asarray(tlwh)
        if np.issubdtype(ret.dtype, np.floating):
            ret = ret.copy()
        else:
            ret = ret.astype(np.float64)
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        """Return the current box as `(center x, center y, aspect ratio, height)`."""
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        """Convert `(min x, min y, max x, max y)` to `(x, y, width, height)`."""
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        """Convert `(x, y, width, height)` to `(min x, min y, max x, max y)`."""
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return (
            f"OT_{self.internal_track_id}_({self.start_frame}-{self.end_frame})"
        )
from supervision.tracker.byte_tracker.single_object_track import STrack, TrackState
from supervision.tracker.byte_tracker.utils import IdCounter


class ByteTrack:
Expand Down Expand Up @@ -230,6 +53,7 @@ def __init__(
self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer)
self.minimum_consecutive_frames = minimum_consecutive_frames
self.kalman_filter = KalmanFilter()
self.shared_kalman = KalmanFilter()

self.tracked_tracks: List[STrack] = []
self.lost_tracks: List[STrack] = []
Expand Down Expand Up @@ -279,7 +103,6 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
)
```
"""

tensors = np.hstack(
(
detections.xyxy,
Expand Down Expand Up @@ -311,7 +134,7 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:

return detections

def reset(self):
def reset(self) -> None:
"""
Resets the internal state of the ByteTrack tracker.
Expand All @@ -323,9 +146,9 @@ def reset(self):
self.frame_id = 0
self.internal_id_counter.reset()
self.external_id_counter.reset()
self.tracked_tracks: List[STrack] = []
self.lost_tracks: List[STrack] = []
self.removed_tracks: List[STrack] = []
self.tracked_tracks = []
self.lost_tracks = []
self.removed_tracks = []

def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
"""
Expand Down Expand Up @@ -361,12 +184,13 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbr),
s,
score_keep,
self.minimum_consecutive_frames,
self.shared_kalman,
self.internal_id_counter,
self.external_id_counter,
)
for (tlbr, s) in zip(dets, scores_keep)
for (tlbr, score_keep) in zip(dets, scores_keep)
]
else:
detections = []
Expand All @@ -384,7 +208,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
""" Step 2: First association, with high score detection boxes"""
strack_pool = joint_tracks(tracked_stracks, self.lost_tracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool)
STrack.multi_predict(strack_pool, self.shared_kalman)
dists = matching.iou_distance(strack_pool, detections)

dists = matching.fuse_score(dists, detections)
Expand All @@ -409,12 +233,13 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
detections_second = [
STrack(
STrack.tlbr_to_tlwh(tlbr),
s,
score_second,
self.minimum_consecutive_frames,
self.shared_kalman,
self.internal_id_counter,
self.external_id_counter,
)
for (tlbr, s) in zip(dets_second, scores_second)
for (tlbr, score_second) in zip(dets_second, scores_second)
]
else:
detections_second = []
Expand All @@ -440,7 +265,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
track.state = TrackState.Lost
lost_stracks.append(track)

"""Deal with unconfirmed tracks, usually tracks with only one beginning frame"""
Expand All @@ -456,7 +281,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
activated_starcks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
track.state = TrackState.Removed
removed_stracks.append(track)

""" Step 4: Init new stracks"""
Expand All @@ -468,8 +293,8 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
activated_starcks.append(track)
""" Step 5: Update state"""
for track in self.lost_tracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
if self.frame_id - track.frame_id > self.max_time_lost:
track.state = TrackState.Removed
removed_stracks.append(track)

self.tracked_tracks = [
Expand Down Expand Up @@ -515,7 +340,7 @@ def joint_tracks(
return result


def sub_tracks(track_list_a: List, track_list_b: List) -> List[int]:
def sub_tracks(track_list_a: List[STrack], track_list_b: List[STrack]) -> List[int]:
"""
Returns a list of tracks from track_list_a after removing any tracks
that share the same internal_track_id with tracks in track_list_b.
Expand All @@ -536,7 +361,9 @@ def sub_tracks(track_list_a: List, track_list_b: List) -> List[int]:
return list(tracks.values())


def remove_duplicate_tracks(tracks_a: List, tracks_b: List) -> Tuple[List, List]:
def remove_duplicate_tracks(
tracks_a: List[STrack], tracks_b: List[STrack]
) -> Tuple[List[STrack], List[STrack]]:
pairwise_distance = matching.iou_distance(tracks_a, tracks_b)
matching_pairs = np.where(pairwise_distance < 0.15)

Expand Down
Loading

0 comments on commit 1f610d8

Please sign in to comment.