Skip to content

Commit

Permalink
Merge pull request #14 from kaanakan/dev
Browse files Browse the repository at this point in the history
empty detections bugfix
  • Loading branch information
kaanakan authored Dec 5, 2021
2 parents 6af4502 + 2007809 commit 5df948f
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def box_area(box):


class ConfusionMatrix:
def __init__(self, num_classes, CONF_THRESHOLD=0.3, IOU_THRESHOLD=0.5):
def __init__(self, num_classes: int, CONF_THRESHOLD=0.3, IOU_THRESHOLD=0.5):
self.matrix = np.zeros((num_classes + 1, num_classes + 1))
self.num_classes = num_classes
self.CONF_THRESHOLD = CONF_THRESHOLD
self.IOU_THRESHOLD = IOU_THRESHOLD

def process_batch(self, detections, labels):
def process_batch(self, detections, labels: np.ndarray):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Expand All @@ -47,8 +47,17 @@ def process_batch(self, detections, labels):
Returns:
None, updates confusion matrix accordingly
"""
detections = detections[detections[:, 4] > self.CONF_THRESHOLD]
gt_classes = labels[:, 0].astype(np.int16)

try:
detections = detections[detections[:, 4] > self.CONF_THRESHOLD]
except IndexError or TypeError:
# detections are empty, end of process
for i, label in enumerate(labels):
gt_class = gt_classes[i]
self.matrix[self.num_classes, gt_class] += 1
return

detection_classes = detections[:, 5].astype(np.int16)

all_ious = box_iou_calc(labels[:, 1:], detections[:, :4])
Expand Down

0 comments on commit 5df948f

Please sign in to comment.