Skip to content

Commit cd91fed

Browse files
author
Agustín Castro
committed
Allow for giving a single score for the whole object
1 parent 009a1b1 commit cd91fed

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

norfair/tracker.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ class Detection:
746746
Parameters
747747
----------
748748
points : np.ndarray
749-
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)` where n_dimensions is 2 or 3.
749+
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)`.
750750
scores : np.ndarray, optional
751751
An array of length `n_points` which assigns a score to each of the points defined in `points`.
752752
@@ -770,12 +770,19 @@ class Detection:
770770
def __init__(
771771
self,
772772
points: np.ndarray,
773-
scores: np.ndarray = None,
773+
scores: Union[float, int, np.ndarray] = None,
774774
data: Any = None,
775775
label: Hashable = None,
776776
embedding=None,
777777
):
778778
self.points = validate_points(points)
779+
780+
if isinstance(scores, np.ndarray):
781+
assert len(scores) == len(
782+
self.points
783+
), "scores should be a np.ndarray with it's length being equal to the amount of points."
784+
else:
785+
scores = np.zeros((len(points),)) + scores
779786
self.scores = scores
780787
self.data = data
781788
self.label = label

norfair/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def validate_points(points: np.ndarray) -> np.array:
2020

2121
def raise_detection_error_message(points):
2222
message = "\n[red]INPUT ERROR:[/red]\n"
23-
message += f"Each `Detection` object should have a property `points` of shape (num_of_points_to_track, 2), not {points.shape}. Check your `Detection` list creation code.\n"
23+
message += f"Each `Detection` object should have a property `points` of shape (n_points, n_dimensions), not {points.shape}. Check your `Detection` list creation code.\n"
2424
message += "You can read the documentation for the `Detection` class here:\n"
2525
message += "https://tryolabs.github.io/norfair/reference/tracker/#norfair.tracker.Detection\n"
2626
raise ValueError(message)

0 commit comments

Comments
 (0)