Skip to content

Commit

Permalink
Merge pull request #1035 from rolson24/redesign-update_with_detection…
Browse files Browse the repository at this point in the history
…s-to-match-Detections-with-tracker_id-v2

Fix issue #754
  • Loading branch information
SkalskiP authored Mar 25, 2024
2 parents 4a6a6db + 1eb877c commit 495f9a9
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions supervision/tracker/byte_tracker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.detection.utils import box_iou_batch
from supervision.tracker.byte_tracker import matching
from supervision.tracker.byte_tracker.basetrack import BaseTrack, TrackState
from supervision.tracker.byte_tracker.kalman_filter import KalmanFilter
Expand Down Expand Up @@ -270,27 +271,35 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
```
"""

tracks = self.update_with_tensors(
tensors=detections2boxes(detections=detections)
)
detections = Detections.empty()
tensors = detections2boxes(detections=detections)

tracks = self.update_with_tensors(tensors=tensors)

final_detections = Detections.empty()

if len(tracks) > 0:
detections.xyxy = np.array(
[track.tlbr for track in tracks], dtype=np.float32
)
detections.class_id = np.array(
[int(t.class_ids) for t in tracks], dtype=int
)
detections.tracker_id = np.array(
[int(t.track_id) for t in tracks], dtype=int
)
detections.confidence = np.array(
[t.score for t in tracks], dtype=np.float32
)
detection_bounding_boxes = np.asarray([det[:4] for det in tensors])
track_bounding_boxes = np.asarray([track.tlbr for track in tracks])

ious = box_iou_batch(detection_bounding_boxes, track_bounding_boxes)

iou_costs = 1 - ious

matches, _, _ = matching.linear_assignment(iou_costs, 0.5)
for i, idet, itrack in enumerate(matches):
if i == 0:
final_detections = detections[[idet]]
final_detections.tracker_id[0] = int(tracks[itrack].track_id)
else:
current_detection = detections[[idet]]
current_detection.tracker_id[0] = int(tracks[itrack].track_id)
final_detections = Detections.merge(
[final_detections, current_detection]
)
else:
detections.tracker_id = np.array([], dtype=int)
final_detections.tracker_id = np.array([], dtype=int)

return detections
return final_detections

def reset(self):
"""
Expand Down

0 comments on commit 495f9a9

Please sign in to comment.