From 5c52c7be56aaeb7d090c047bc6d58dfe8e263d7c Mon Sep 17 00:00:00 2001 From: Raif Olson Date: Thu, 21 Mar 2024 21:51:10 -0500 Subject: [PATCH 1/5] Add confidence tracker --- .../tracker/confidence_tracker/basetrack.py | 61 ++ .../tracker/confidence_tracker/core.py | 574 ++++++++++++++++++ .../confidence_tracker/kalman_filter.py | 267 ++++++++ .../tracker/confidence_tracker/matching.py | 84 +++ 4 files changed, 986 insertions(+) create mode 100644 supervision/tracker/confidence_tracker/basetrack.py create mode 100644 supervision/tracker/confidence_tracker/core.py create mode 100644 supervision/tracker/confidence_tracker/kalman_filter.py create mode 100644 supervision/tracker/confidence_tracker/matching.py diff --git a/supervision/tracker/confidence_tracker/basetrack.py b/supervision/tracker/confidence_tracker/basetrack.py new file mode 100644 index 000000000..36e28e6fa --- /dev/null +++ b/supervision/tracker/confidence_tracker/basetrack.py @@ -0,0 +1,61 @@ +from collections import OrderedDict + +import numpy as np + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + LongLost = 3 + Removed = 4 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + score = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + def mark_long_lost(self): + self.state = TrackState.LongLost + + def mark_removed(self): + self.state = TrackState.Removed + + @staticmethod + def clear_count(): + BaseTrack._count = 0 diff --git 
a/supervision/tracker/confidence_tracker/core.py b/supervision/tracker/confidence_tracker/core.py new file mode 100644 index 000000000..c401af1e1 --- /dev/null +++ b/supervision/tracker/confidence_tracker/core.py @@ -0,0 +1,574 @@ +from collections import deque +from typing import List + +import numpy as np + +from supervision.detection.core import Detections +from supervision.tracker.confidence_tracker import matching +from supervision.tracker.confidence_tracker.basetrack import BaseTrack, TrackState +from supervision.tracker.confidence_tracker.kalman_filter import KalmanFilter + + +class STrack(BaseTrack): + shared_kalman = KalmanFilter(0.6, 10) + + def __init__(self, tlwh, score, class_id, feat=None, feat_history=50): + # wait activate + self._tlwh = np.asarray(tlwh, dtype=float) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + + self.cls = -1 + self.cls_hist = [] # (cls id, freq) + self.update_cls(class_id, score) + + self.smooth_feat = None + self.curr_feat = None + self.features = deque([], maxlen=feat_history) # set before update_features + self.alpha = 0.9 + if feat is not None: + self.update_features(feat) + + def update_features(self, feat): + feat /= np.linalg.norm(feat) + self.curr_feat = feat + if self.smooth_feat is None: + self.smooth_feat = feat + else: + self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat + self.features.append(feat) + self.smooth_feat /= np.linalg.norm(self.smooth_feat) + + def update_cls(self, cls, score): + if len(self.cls_hist) > 0: + max_freq = 0 + found = False + for c in self.cls_hist: + if cls == c[0]: + c[1] += score + found = True + + if c[1] > max_freq: + max_freq = c[1] + self.cls = c[0] + if not found: + self.cls_hist.append([cls, score]) + self.cls = cls + else: + self.cls_hist.append([cls, score]) + self.cls = cls + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[6] = 0 
+ mean_state[7] = 0 + + self.mean, self.covariance = self.kalman_filter.predict( + mean_state, self.covariance + ) + + @staticmethod + def multi_predict(stracks): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][6] = 0 + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict( + multi_mean, multi_covariance + ) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + + self.mean, self.covariance = self.kalman_filter.initiate( + self.tlwh_to_xywh(self._tlwh) + ) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xywh(new_track.tlwh), self.score + ) + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + self.score = new_track.score + + self.update_cls(new_track.cls, new_track.score) + + def update(self, new_track, frame_id): + """ + Update a matched track + :type new_track: STrack + :type frame_id: int + :type update_feature: bool + :return: + """ + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + + self.score = new_track.score + + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xywh(new_tlwh), 
self.score + ) + + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + + self.state = TrackState.Tracked + self.is_activated = True + + self.update_cls(new_track.cls, new_track.score) + + @property + def tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[:2] -= ret[2:] / 2 + return ret + + @property + def tlbr(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @property + def xywh(self): + """Convert bounding box to format `(center x, center y, width, + height)`. + """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2.0 + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. + """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + @staticmethod + def tlwh_to_xywh(tlwh): + """Convert bounding box to format `(center x, center y, width, + height)`. + """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + def to_xywh(self): + return self.tlwh_to_xywh(self.tlwh) + + @staticmethod + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return "OT_{}_({}-{})".format(self.track_id, self.start_frame, self.end_frame) + + +def detections2boxes(detections: Detections) -> np.ndarray: + """ + Convert Supervision Detections to numpy tensors for further computation. + Args: + detections (Detections): Detections/Targets in the format of sv.Detections. 
+ features (ndarray): The corresponding image features of each detection. + Has shape [N, num_features] + Returns: + (np.ndarray): Detections as numpy tensors as in + `(x_min, y_min, x_max, y_max, confidence, class_id, feature_vect)` order. + """ + return np.hstack( + ( + detections.xyxy, + detections.confidence[:, np.newaxis], + detections.class_id[:, np.newaxis], + ) + ) + + +class ConfTrack: + def __init__( + self, + track_high_thresh: float = 0.6, + track_low_thresh: float = 0.2, + new_track_thresh: float = 0.1, + tent_conf_thresh: float = 0.7, + minimum_matching_threshold: float = 0.8, + lost_track_buffer: int = 30, + proximity_thresh: float = 0.6, + frame_rate: int = 30, + ): + BaseTrack.clear_count() + + self.frame_id = 0 + + self.track_high_thresh = track_high_thresh + self.track_low_thresh = track_low_thresh + self.new_track_thresh = new_track_thresh + + self.tent_conf_thresh = tent_conf_thresh + + self.minimum_matching_threshold = minimum_matching_threshold + + # self.buffer_size = int(frame_rate / 30.0 * args.lost_track_buffer) + self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer) + self.kalman_filter = KalmanFilter(0.6, 10) + + self.tracked_tracks: List[STrack] = [] + self.lost_tracks: List[STrack] = [] + self.removed_tracks: List[STrack] = [] + + def update_with_detections(self, detections: Detections) -> Detections: + """ + Updates the tracker with the provided detections and + returns the updated detection results. + + Parameters: + detections: The new detections to update with. + Returns: + Detection: The updated detection results that now include tracking IDs. + Example: + ```python + >>> import supervision as sv + >>> from ultralytics import YOLO + + >>> model = YOLO(...) + >>> byte_tracker = sv.ByteTrack() + >>> annotator = sv.BoxAnnotator() + + >>> def callback(frame: np.ndarray, index: int) -> np.ndarray: + ... results = model(frame)[0] + ... detections = sv.Detections.from_ultralytics(results) + ... 
detections = byte_tracker.update_with_detections(detections) + ... labels = [ + ... f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" + ... for _, _, confidence, class_id, tracker_id + ... in detections + ... ] + ... return annotator.annotate(scene=frame.copy(), + ... detections=detections, labels=labels) + + >>> sv.process_video( + ... source_path='...', + ... target_path='...', + ... callback=callback + ... ) + ``` + """ + tensors = detections2boxes(detections) + # print(f"tensors: {tensors}") + tracks = self.update_with_tensors( + # maybe extract features here + tensors + ) + detections = Detections.empty() + if len(tracks) > 0: + detections.xyxy = np.array([track.tlbr for track in tracks], dtype=float) + detections.class_id = np.array([int(t.cls) for t in tracks], dtype=int) + detections.tracker_id = np.array( + [int(t.track_id) for t in tracks], dtype=int + ) + detections.confidence = np.array([t.score for t in tracks], dtype=float) + else: + detections.tracker_id = np.array([], dtype=int) + + return detections + + def reset(self): + """ + Resets the internal state of the ByteTrack tracker. + + This method clears the tracking data, including tracked, lost, + and removed tracks, as well as resetting the frame counter. It's + particularly useful when processing multiple videos sequentially, + ensuring the tracker starts with a clean state for each new video. + """ + self.frame_id = 0 + self.tracked_tracks: List[STrack] = [] + self.lost_tracks: List[STrack] = [] + self.removed_tracks: List[STrack] = [] + BaseTrack.clear_count() # same id-counter reset that __init__ performs + + def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: + """ + Updates the tracker with the provided tensors and returns the updated tracks. + + Parameters: + tensors: The new tensors to update with. + + Returns: + List[STrack]: Updated tracks. 
+ """ + self.frame_id += 1 + activated_starcks = [] + refind_stracks = [] + lost_tracks = [] + removed_stracks = [] + + class_ids = tensors[:, 5] + scores = tensors[:, 4] + bboxes = tensors[:, :4] + + # Remove bad detections + inds_low = scores > self.track_low_thresh + bboxes = bboxes[inds_low] + scores = scores[inds_low] + class_ids = class_ids[inds_low] + + # Find high threshold detections + inds_high = scores > self.track_high_thresh + dets = bboxes[inds_high] + scores_keep = scores[inds_high] + class_ids_keep = class_ids[inds_high] + + # Find low threshold detections + inds_second = np.logical_not(inds_high) # kept boxes that are not high-score + dets_second = bboxes[inds_second] + scores_second = scores[inds_second] + class_ids_second = class_ids[inds_second] + + if len(dets) > 0: + """Detections""" + detections = [ + STrack(STrack.tlbr_to_tlwh(tlbr), s, cl) + for (tlbr, s, cl) in zip(dets, scores_keep, class_ids_keep) + ] + else: + detections = [] + + """ Add newly detected tracklets to tracked_tracks""" + low_tent = [] # type: list[STrack] + high_tent = [] # type: list[STrack] + tracked_tracks = [] # type: list[STrack] + for track in self.tracked_tracks: + if not track.is_activated: + # implement LM from ConfTrack paper + if track.score < self.tent_conf_thresh: + low_tent.append(track) + else: + high_tent.append(track) + else: + tracked_tracks.append(track) + + """ Step 2: First association, with high score detection boxes""" + strack_pool = joint_stracks(tracked_tracks, self.lost_tracks) + + # Predict the current location with KF + STrack.multi_predict(strack_pool) + + # LM algorithm + strack_pool = joint_stracks(strack_pool, high_tent) + + # Associate with high score detection boxes + ious_dists = matching.iou_distance(strack_pool, detections) + + # Fuse the iou's with the scores + ious_dists = matching.fuse_score(ious_dists, detections) + + matches, track_conf_remain, det_high_remain = matching.linear_assignment( + ious_dists, thresh=self.minimum_matching_threshold + ) + + for 
itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + """ Step 3: Second association, with low score detection boxes""" + # association the untrack to the low score detections + if len(dets_second) > 0: + """Detections""" + detections_second = [ + STrack(STrack.tlbr_to_tlwh(tlbr), s, cl) + for (tlbr, s, cl) in zip(dets_second, scores_second, class_ids_second) + ] + else: + detections_second = [] + + r_tracked_tracks = [ + strack_pool[i] + for i in track_conf_remain + if strack_pool[i].state == TrackState.Tracked + ] + dists = matching.iou_distance(r_tracked_tracks, detections_second) + matches, track_conf_remain, det_low_remain = matching.linear_assignment( + dists, thresh=0.5 + ) + for itracked, idet in matches: + track = r_tracked_tracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + # implement LM from ConfTrack paper + """Step 4: low-confidence track matching with high-conf dets""" + # Associate with high score detection boxes + stracks_conf_remain = [r_tracked_tracks[i] for i in track_conf_remain] + ious_dists = matching.iou_distance(low_tent, stracks_conf_remain) + _, low_tent_valid, _ = matching.linear_assignment( + ious_dists, thresh=1 - 0.7 + ) # want to get rid of tracks with low iou costs + stracks_low_tent_valid = [low_tent[i] for i in low_tent_valid] + stracks_det_high_remain = [detections[i] for i in det_high_remain] + C_low_ious = matching.iou_distance( + stracks_low_tent_valid, stracks_det_high_remain + ) + + matches, track_tent_remain, det_high_remain = matching.linear_assignment( + C_low_ious, thresh=0.3 + 
) # need to find this val in ConfTrack paper + + for itracked, idet in matches: + low_tent[itracked].update(stracks_det_high_remain[idet], self.frame_id) + activated_starcks.append(low_tent[itracked]) + + """Deal with unconfirmed tracks, usually tracks with only one beginning frame""" + for it in track_tent_remain: + track = stracks_low_tent_valid[it] + track.mark_removed() + removed_stracks.append(track) + # left over confirmed tracks get lost + for it in track_conf_remain: + track = r_tracked_tracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_tracks.append(track) + + """ Step 5: Init new stracks""" + for inew in det_high_remain: + track = stracks_det_high_remain[inew] + if track.score < self.new_track_thresh: + continue + + track.activate(self.kalman_filter, self.frame_id) + activated_starcks.append(track) + + for inew in det_low_remain: + track = detections_second[inew] + if track.score < self.new_track_thresh: + continue + + track.activate(self.kalman_filter, self.frame_id) + activated_starcks.append(track) + + """ Step 6: Update state""" + for track in self.lost_tracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + """ Merge """ + self.tracked_tracks = [ + t for t in self.tracked_tracks if t.state == TrackState.Tracked + ] + self.tracked_tracks = joint_stracks(self.tracked_tracks, activated_starcks) + self.tracked_tracks = joint_stracks(self.tracked_tracks, refind_stracks) + self.lost_tracks = sub_stracks(self.lost_tracks, self.tracked_tracks) + self.lost_tracks.extend(lost_tracks) + self.lost_tracks = sub_stracks(self.lost_tracks, self.removed_tracks) + self.removed_tracks.extend(removed_stracks) + self.tracked_tracks, self.lost_tracks = remove_duplicate_stracks( + self.tracked_tracks, self.lost_tracks + ) + + output_stracks = [track for track in self.tracked_tracks] + + return output_stracks + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + 
for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + return resa, resb diff --git a/supervision/tracker/confidence_tracker/kalman_filter.py b/supervision/tracker/confidence_tracker/kalman_filter.py new file mode 100644 index 000000000..740c454da --- /dev/null +++ b/supervision/tracker/confidence_tracker/kalman_filter.py @@ -0,0 +1,267 @@ +from typing import Tuple + +import numpy as np +import scipy.linalg + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919, +} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, w, h, vx, vy, vw, vh + + contains the bounding box center position (x, y), width w, height h, + and their respective velocities. + + Object motion follows a constant velocity model. 
The bounding box location + (x, y, w, h) is taken as direct observation of the state space (linear + observation model). + + Parameters: + conf_thresh (float, optional): For detection boxes whose confidence score + are lower than conf_threshold, change the location of the "measurement" + (detected box) to be closer to the location of the predicted box, + weighted by the confidence of the box. + + Equation: + measurement = measurement + (projected_mean - measurement) * conf_cost + + cov_alpha (int, optional): Amplifying coefficient for Noise Scale Adaptive + Kalman Filter. If the confidence is low, then the noise in the + covariance projected into the measurement space is greater. + + Equation: + innovation_cov = innovation_cov * (1 - conf) * self.cov_alpha + + + """ + + def __init__(self, conf_thresh: float = 0.6, cov_alpha: int = 10): + ndim, dt = 4, 1.0 + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1.0 / 20 + self._std_weight_velocity = 1.0 / 160 + + self.conf_thresh = conf_thresh + self.cov_alpha = cov_alpha + + def initiate(self, measurement): + """Create track from unassociated measurement. + + Parameters: + measurement (ndarray) + Bounding box coordinates (x, y, w, h) with center position (x, y), + width w, and height h. + + Returns: + Tuple[ndarray, ndarray]: Returns the mean vector (8 dimensional) and + covariance matrix (8x8 dimensional) of the new track. + Unobserved velocities are initialized to 0 mean. 
+ + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + ] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict( + self, mean: np.ndarray, covariance: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter prediction step. + + Args: + mean (ndarray): The 8 dimensional mean vector of the object + state at the previous time step. + covariance (ndarray): The 8x8 dimensional covariance matrix of + the object state at the previous time step. + + Returns: + Tuple[ndarray, ndarray]: Returns the mean vector and + covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] + std_vel = [ + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + ] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(mean, self._motion_mat.T) + covariance = ( + np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + + motion_cov + ) + + return mean, covariance + + def project( + self, mean: np.ndarray, covariance: np.ndarray, confidence: float + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Project state distribution to measurement space. + + Args: + mean (ndarray): The state's mean vector (8 dimensional array). 
+ covariance (ndarray): The state's covariance matrix (8x8 dimensional). + confidence (float): The confidence of the current measurement. + + Returns: + Tuple[ndarray, ndarray]: Returns the projected mean and + covariance matrix of the given state estimate. + """ + std = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] + innovation_cov = np.diag(np.square(std)) + + # implement NK from ConfTrack paper + innovation_cov = innovation_cov * (1 - confidence) * self.cov_alpha + + mean = np.dot(self._update_mat, mean) + + covariance = np.linalg.multi_dot( + (self._update_mat, covariance, self._update_mat.T) + ) + return mean, covariance + innovation_cov + + def multi_predict( + self, mean: np.ndarray, covariance: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter prediction step (Vectorized version). + + Args: + mean (ndarray): The Nx8 dimensional mean matrix + of the object states at the previous time step. + covariance (ndarray): The Nx8x8 dimensional covariance matrices + of the object states at the previous time step. + + Returns: + Tuple[ndarray, ndarray]: Returns the mean vector and + covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. 
+ """ + std_pos = [ + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + ] + std_vel = [ + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + ] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update( + self, + mean: np.ndarray, + covariance: np.ndarray, + measurement: np.ndarray, + confidence: float, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + measurement (ndarray): The 4-dimensional measurement vector (x, y, w, h), + where (x, y) is the center position, w the width, + and h the height of the bounding box. + confidence (float): The confidence of the measurement. + + Returns: + Tuple[ndarray, ndarray]: Returns the measurement-corrected + state distribution. 
+ """ + projected_mean, projected_cov = self.project(mean, covariance, confidence) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False + ) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), + np.dot(covariance, self._update_mat.T).T, + check_finite=False, + ).T + + # implement CW from ConfTrack paper + if confidence < self.conf_thresh: + conf_cost = 1 - confidence + measurement = measurement + (projected_mean - measurement) * conf_cost + + innovation = measurement - projected_mean # ~yk in GIAO paper + + new_mean = mean + np.dot(innovation, kalman_gain.T) + + new_covariance = covariance - np.linalg.multi_dot( + (kalman_gain, projected_cov, kalman_gain.T) + ) + + return new_mean, new_covariance diff --git a/supervision/tracker/confidence_tracker/matching.py b/supervision/tracker/confidence_tracker/matching.py new file mode 100644 index 000000000..22b75c1c9 --- /dev/null +++ b/supervision/tracker/confidence_tracker/matching.py @@ -0,0 +1,84 @@ +from typing import List, Tuple + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from supervision.detection.utils import box_iou_batch + + +def indices_to_matches( + cost_matrix: np.ndarray, indices: np.ndarray, thresh: float +) -> Tuple[np.ndarray, tuple, tuple]: + matched_cost = cost_matrix[tuple(zip(*indices))] + matched_mask = matched_cost <= thresh + + matches = indices[matched_mask] + unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) + return matches, unmatched_a, unmatched_b + + +def linear_assignment( + cost_matrix: np.ndarray, thresh: float +) -> Tuple[np.ndarray, Tuple[int, ...], Tuple[int, ...]]: + if cost_matrix.size == 0: + return ( + np.empty((0, 2), dtype=int), + tuple(range(cost_matrix.shape[0])), + tuple(range(cost_matrix.shape[1])), + ) + + cost_matrix[cost_matrix > thresh] = thresh + 1e-4 + row_ind, col_ind = 
linear_sum_assignment(cost_matrix) + indices = np.column_stack((row_ind, col_ind)) + + return indices_to_matches(cost_matrix, indices, thresh) + + +def ious(atlbrs, btlbrs): + """ + Compute cost based on IoU + :type atlbrs: list[tlbr] | np.ndarray + :type atlbrs: list[tlbr] | np.ndarray + + :rtype ious np.ndarray + """ + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=float) + if ious.size == 0: + return ious + + ious = box_iou_batch( + np.ascontiguousarray(atlbrs, dtype=float), + np.ascontiguousarray(btlbrs, dtype=float), + ) + + return ious + + +def iou_distance(atracks: List, btracks: List) -> np.ndarray: + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( + len(btracks) > 0 and isinstance(btracks[0], np.ndarray) + ): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlbr for track in atracks] + btlbrs = [track.tlbr for track in btracks] + + _ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if _ious.size != 0: + _ious = box_iou_batch(np.asarray(atlbrs), np.asarray(btlbrs)) + cost_matrix = 1 - _ious + + return cost_matrix + + +def fuse_score(cost_matrix, detections): + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + fuse_cost = 1 - fuse_sim + return fuse_cost From 875e8263d8f7b0a5f6785b1e021e71559280b7db Mon Sep 17 00:00:00 2001 From: Raif Olson Date: Fri, 22 Mar 2024 10:06:43 -0500 Subject: [PATCH 2/5] Update docstrings and remove unnecessary ConfTrack initialization args. 
--- .../tracker/confidence_tracker/core.py | 67 +++++++++++++------ 1 file changed, 45 insertions(+), 22 deletions(-) diff --git a/supervision/tracker/confidence_tracker/core.py b/supervision/tracker/confidence_tracker/core.py index c401af1e1..f440eaee6 100644 --- a/supervision/tracker/confidence_tracker/core.py +++ b/supervision/tracker/confidence_tracker/core.py @@ -235,26 +235,48 @@ def detections2boxes(detections: Detections) -> np.ndarray: class ConfTrack: + """ + Initialize the ConfTrack object. + + Parameters: + detection_high_threshold (float, optional): Detection confidence threshold + for first matching step. Increasing detection_high_threshold increases + track initialization accuracy at the risk of missing low-confidence + tracks altogether. + Decreasing it increases the number of tracks initialized but risks + initializing short, likely invalid tracks. + tentative_track_high_threshold (float, optional): Track confidence threshold + for matching tentative tracks. Similar to detection_high_threshold, + increasing it increases track initialization accuracy at the risk of + missing low-confidence tracks. + lost_track_buffer (int, optional): Number of frames to buffer when a track is + lost. + Increasing lost_track_buffer enhances occlusion handling, significantly + reducing the likelihood of track fragmentation or disappearance caused + by brief detection gaps. + minimum_matching_threshold (float, optional): Threshold for matching tracks + with detections. + Increasing minimum_matching_threshold improves accuracy but risks + fragmentation. + Decreasing it improves completeness but risks false positives and drift. + frame_rate (int, optional): The frame rate of the video. 
+ """ + def __init__( self, - track_high_thresh: float = 0.6, - track_low_thresh: float = 0.2, - new_track_thresh: float = 0.1, - tent_conf_thresh: float = 0.7, + detection_high_threshold: float = 0.6, + tentative_track_high_threshold: float = 0.7, minimum_matching_threshold: float = 0.8, lost_track_buffer: int = 30, - proximity_thresh: float = 0.6, frame_rate: int = 30, ): BaseTrack.clear_count() self.frame_id = 0 - self.track_high_thresh = track_high_thresh - self.track_low_thresh = track_low_thresh - self.new_track_thresh = new_track_thresh + self.detection_high_threshold = detection_high_threshold - self.tent_conf_thresh = tent_conf_thresh + self.tentative_track_high_threshold = tentative_track_high_threshold self.minimum_matching_threshold = minimum_matching_threshold @@ -281,13 +303,13 @@ def update_with_detections(self, detections: Detections) -> Detections: >>> from ultralytics import YOLO >>> model = YOLO(...) - >>> byte_tracker = sv.ByteTrack() + >>> conf_tracker = sv.ConfTrack() >>> annotator = sv.BoxAnnotator() >>> def callback(frame: np.ndarray, index: int) -> np.ndarray: ... results = model(frame)[0] ... detections = sv.Detections.from_ultralytics(results) - ... detections = byte_tracker.update_with_detections(detections) + ... detections = conf_tracker.update_with_detections(detections) ... labels = [ ... f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" ... for _, _, confidence, class_id, tracker_id @@ -324,7 +346,7 @@ def update_with_detections(self, detections: Detections) -> Detections: def reset(self): """ - Resets the internal state of the ByteTrack tracker. + Resets the internal state of the ConfTrack tracker. This method clears the tracking data, including tracked, lost, and removed tracks, as well as resetting the frame counter. 
It's @@ -358,13 +380,13 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: bboxes = tensors[:, :4] # Remove bad detections - inds_low = scores > self.track_low_thresh + inds_low = scores > 0.2 bboxes = bboxes[inds_low] scores = scores[inds_low] class_ids = class_ids[inds_low] # Find high threshold detections - inds_high = scores > self.track_high_thresh + inds_high = scores > self.detection_high_threshold dets = bboxes[inds_high] scores_keep = scores[inds_high] class_ids_keep = class_ids[inds_high] @@ -391,7 +413,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: for track in self.tracked_tracks: if not track.is_activated: # implement LM from ConfTrack paper - if track.score < self.tent_conf_thresh: + if track.score < self.tentative_track_high_threshold: low_tent.append(track) else: high_tent.append(track) @@ -428,7 +450,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: refind_stracks.append(track) """ Step 3: Second association, with low score detection boxes""" - # association the untrack to the low score detections + # association the untracked to the low score detections if len(dets_second) > 0: """Detections""" detections_second = [ @@ -458,12 +480,12 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: refind_stracks.append(track) # implement LM from ConfTrack paper - """Step 4: low-confidence track matching with high-conf dets""" + """Step 4: low-confidence track matching with high-confidence detections""" # Associate with high score detection boxes stracks_conf_remain = [r_tracked_tracks[i] for i in track_conf_remain] ious_dists = matching.iou_distance(low_tent, stracks_conf_remain) _, low_tent_valid, _ = matching.linear_assignment( - ious_dists, thresh=1 - 0.7 + ious_dists, thresh=(1 - 0.7) ) # want to get rid of tracks with low iou costs stracks_low_tent_valid = [low_tent[i] for i in low_tent_valid] stracks_det_high_remain = [detections[i] for i in det_high_remain] @@ 
-473,13 +495,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: matches, track_tent_remain, det_high_remain = matching.linear_assignment( C_low_ious, thresh=0.3 - ) # need to find this val in ConfTrack paper + ) # thresh is from ConfTrack paper for itracked, idet in matches: low_tent[itracked].update(stracks_det_high_remain[idet], self.frame_id) activated_starcks.append(low_tent[itracked]) - """Deal with unconfirmed tracks, usually tracks with only one beginning frame""" + """Deal with left over tentative tracks, + usually tracks with only one beginning frame""" for it in track_tent_remain: track = stracks_low_tent_valid[it] track.mark_removed() @@ -494,7 +517,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: """ Step 5: Init new stracks""" for inew in det_high_remain: track = stracks_det_high_remain[inew] - if track.score < self.new_track_thresh: + if track.score < 0.1: continue track.activate(self.kalman_filter, self.frame_id) @@ -502,7 +525,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: for inew in det_low_remain: track = detections_second[inew] - if track.score < self.new_track_thresh: + if track.score < 0.1: continue track.activate(self.kalman_filter, self.frame_id) From 6d77f1d404eaf37ae0eacfea2282c0d43c237791 Mon Sep 17 00:00:00 2001 From: Raif Olson Date: Fri, 22 Mar 2024 10:28:47 -0500 Subject: [PATCH 3/5] Add tracker to init import --- supervision/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/supervision/__init__.py b/supervision/__init__.py index 35f525b04..34def4802 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -71,6 +71,7 @@ from supervision.geometry.utils import get_polygon_center from supervision.metrics.detection import ConfusionMatrix, MeanAveragePrecision from supervision.tracker.byte_tracker.core import ByteTrack +from supervision.tracker.confidence_tracker.core import ConfTrack from supervision.utils.file import 
list_files_with_extensions from supervision.utils.image import ImageSink, crop_image, place_image, resize_image from supervision.utils.notebook import plot_image, plot_images_grid From 9718e3377222a39542f281326a5f77d88341b0a8 Mon Sep 17 00:00:00 2001 From: Raif Olson Date: Fri, 22 Mar 2024 10:34:45 -0500 Subject: [PATCH 4/5] Fix naming of self.removed_tracks --- supervision/tracker/confidence_tracker/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/tracker/confidence_tracker/core.py b/supervision/tracker/confidence_tracker/core.py index f440eaee6..c7290eaaf 100644 --- a/supervision/tracker/confidence_tracker/core.py +++ b/supervision/tracker/confidence_tracker/core.py @@ -545,7 +545,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: self.tracked_tracks = joint_stracks(self.tracked_tracks, refind_stracks) self.lost_tracks = sub_stracks(self.lost_tracks, self.tracked_tracks) self.lost_tracks.extend(lost_tracks) - self.lost_tracks = sub_stracks(self.lost_tracks, self.removed_stracks) + self.lost_tracks = sub_stracks(self.lost_tracks, self.removed_tracks) self.removed_tracks.extend(removed_stracks) self.tracked_tracks, self.lost_tracks = remove_duplicate_stracks( self.tracked_tracks, self.lost_tracks From c6339690b2b8712b0b45894874ce56398af6cdff Mon Sep 17 00:00:00 2001 From: Raif Olson Date: Fri, 22 Mar 2024 19:23:51 -0500 Subject: [PATCH 5/5] remove unused functions and change track list operation functions to be the same as ByteTrack --- .../tracker/confidence_tracker/core.py | 128 +++++++++++------- .../tracker/confidence_tracker/matching.py | 20 --- 2 files changed, 76 insertions(+), 72 deletions(-) diff --git a/supervision/tracker/confidence_tracker/core.py b/supervision/tracker/confidence_tracker/core.py index c7290eaaf..4c2c653d5 100644 --- a/supervision/tracker/confidence_tracker/core.py +++ b/supervision/tracker/confidence_tracker/core.py @@ -1,5 +1,5 @@ from collections import deque -from typing 
import List +from typing import List, Tuple import numpy as np @@ -280,7 +280,6 @@ def __init__( self.minimum_matching_threshold = minimum_matching_threshold - # self.buffer_size = int(frame_rate / 30.0 * args.lost_track_buffer) self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer) self.kalman_filter = KalmanFilter(0.6, 10) @@ -326,11 +325,7 @@ def update_with_detections(self, detections: Detections) -> Detections: ``` """ tensors = detections2boxes(detections) - # print(f"tensors: {tensors}") - tracks = self.update_with_tensors( - # maybe extract features here - tensors - ) + tracks = self.update_with_tensors(tensors) detections = Detections.empty() if len(tracks) > 0: detections.xyxy = np.array([track.tlbr for track in tracks], dtype=float) @@ -421,13 +416,13 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: tracked_tracks.append(track) """ Step 2: First association, with high score detection boxes""" - strack_pool = joint_stracks(tracked_tracks, self.lost_tracks) + strack_pool = joint_tracks(tracked_tracks, self.lost_tracks) # Predict the current location with KF STrack.multi_predict(strack_pool) # LM algorithm - strack_pool = joint_stracks(strack_pool, high_tent) + strack_pool = joint_tracks(strack_pool, high_tent) # Associate with high score detection boxes ious_dists = matching.iou_distance(strack_pool, detections) @@ -541,13 +536,13 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: self.tracked_tracks = [ t for t in self.tracked_tracks if t.state == TrackState.Tracked ] - self.tracked_tracks = joint_stracks(self.tracked_tracks, activated_starcks) - self.tracked_tracks = joint_stracks(self.tracked_tracks, refind_stracks) - self.lost_tracks = sub_stracks(self.lost_tracks, self.tracked_tracks) + self.tracked_tracks = joint_tracks(self.tracked_tracks, activated_starcks) + self.tracked_tracks = joint_tracks(self.tracked_tracks, refind_stracks) + self.lost_tracks = sub_tracks(self.lost_tracks, 
self.tracked_tracks) self.lost_tracks.extend(lost_tracks) - self.lost_tracks = sub_stracks(self.lost_tracks, self.removed_tracks) + self.lost_tracks = sub_tracks(self.lost_tracks, self.removed_tracks) self.removed_tracks.extend(removed_stracks) - self.tracked_tracks, self.lost_tracks = remove_duplicate_stracks( + self.tracked_tracks, self.lost_tracks = remove_duplicate_tracks( self.tracked_tracks, self.lost_tracks ) @@ -556,42 +551,71 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: return output_stracks -def joint_stracks(tlista, tlistb): - exists = {} - res = [] - for t in tlista: - exists[t.track_id] = 1 - res.append(t) - for t in tlistb: - tid = t.track_id - if not exists.get(tid, 0): - exists[tid] = 1 - res.append(t) - return res - - -def sub_stracks(tlista, tlistb): - stracks = {} - for t in tlista: - stracks[t.track_id] = t - for t in tlistb: - tid = t.track_id - if stracks.get(tid, 0): - del stracks[tid] - return list(stracks.values()) - - -def remove_duplicate_stracks(stracksa, stracksb): - pdist = matching.iou_distance(stracksa, stracksb) - pairs = np.where(pdist < 0.15) - dupa, dupb = list(), list() - for p, q in zip(*pairs): - timep = stracksa[p].frame_id - stracksa[p].start_frame - timeq = stracksb[q].frame_id - stracksb[q].start_frame - if timep > timeq: - dupb.append(q) +def joint_tracks( + track_list_a: List[STrack], track_list_b: List[STrack] +) -> List[STrack]: + """ + Joins two lists of tracks, ensuring that the resulting list does not + contain tracks with duplicate track_id values. + + Parameters: + track_list_a: First list of tracks (with track_id attribute). + track_list_b: Second list of tracks (with track_id attribute). + + Returns: + Combined list of tracks from track_list_a and track_list_b + without duplicate track_id values. 
+ """ + seen_track_ids = set() + result = [] + + for track in track_list_a + track_list_b: + if track.track_id not in seen_track_ids: + seen_track_ids.add(track.track_id) + result.append(track) + + return result + + +def sub_tracks(track_list_a: List, track_list_b: List) -> List[int]: + """ + Returns a list of tracks from track_list_a after removing any tracks + that share the same track_id with tracks in track_list_b. + + Parameters: + track_list_a: List of tracks (with track_id attribute). + track_list_b: List of tracks (with track_id attribute) to + be subtracted from track_list_a. + Returns: + List of remaining tracks from track_list_a after subtraction. + """ + tracks = {track.track_id: track for track in track_list_a} + track_ids_b = {track.track_id for track in track_list_b} + + for track_id in track_ids_b: + tracks.pop(track_id, None) + + return list(tracks.values()) + + +def remove_duplicate_tracks(tracks_a: List, tracks_b: List) -> Tuple[List, List]: + pairwise_distance = matching.iou_distance(tracks_a, tracks_b) + matching_pairs = np.where(pairwise_distance < 0.15) + + duplicates_a, duplicates_b = set(), set() + for track_index_a, track_index_b in zip(*matching_pairs): + time_a = tracks_a[track_index_a].frame_id - tracks_a[track_index_a].start_frame + time_b = tracks_b[track_index_b].frame_id - tracks_b[track_index_b].start_frame + if time_a > time_b: + duplicates_b.add(track_index_b) else: - dupa.append(p) - resa = [t for i, t in enumerate(stracksa) if i not in dupa] - resb = [t for i, t in enumerate(stracksb) if i not in dupb] - return resa, resb + duplicates_a.add(track_index_a) + + result_a = [ + track for index, track in enumerate(tracks_a) if index not in duplicates_a + ] + result_b = [ + track for index, track in enumerate(tracks_b) if index not in duplicates_b + ] + + return result_a, result_b diff --git a/supervision/tracker/confidence_tracker/matching.py b/supervision/tracker/confidence_tracker/matching.py index 22b75c1c9..9afe8e251 100644 --- 
a/supervision/tracker/confidence_tracker/matching.py +++ b/supervision/tracker/confidence_tracker/matching.py @@ -35,26 +35,6 @@ def linear_assignment( return indices_to_matches(cost_matrix, indices, thresh) -def ious(atlbrs, btlbrs): - """ - Compute cost based on IoU - :type atlbrs: list[tlbr] | np.ndarray - :type atlbrs: list[tlbr] | np.ndarray - - :rtype ious np.ndarray - """ - ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=float) - if ious.size == 0: - return ious - - ious = box_iou_batch( - np.ascontiguousarray(atlbrs, dtype=float), - np.ascontiguousarray(btlbrs, dtype=float), - ) - - return ious - - def iou_distance(atracks: List, btracks: List) -> np.ndarray: if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( len(btracks) > 0 and isinstance(btracks[0], np.ndarray)