From d83c5b3f7ee20ddf461afb89caaf54eaea8d85d2 Mon Sep 17 00:00:00 2001 From: Evan Mahony <58450417+evanmahony@users.noreply.github.com> Date: Tue, 7 Dec 2021 03:27:29 +0000 Subject: [PATCH] Changed indexing of the label matrix to match how many classes exist --- ML/Pytorch/object_detection/YOLO/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ML/Pytorch/object_detection/YOLO/dataset.py b/ML/Pytorch/object_detection/YOLO/dataset.py index 2958da79..b8743eb2 100755 --- a/ML/Pytorch/object_detection/YOLO/dataset.py +++ b/ML/Pytorch/object_detection/YOLO/dataset.py @@ -73,16 +73,16 @@ def __getitem__(self, index): # If no object already found for specific cell i,j # Note: This means we restrict to ONE object # per cell! - if label_matrix[i, j, 20] == 0: + if label_matrix[i, j, self.C] == 0: # Set that there exists an object - label_matrix[i, j, 20] = 1 + label_matrix[i, j, self.C] = 1 # Box coordinates box_coordinates = torch.tensor( [x_cell, y_cell, width_cell, height_cell] ) - label_matrix[i, j, 21:25] = box_coordinates + label_matrix[i, j, (self.C + 1):(self.C + 5)] = box_coordinates # Set one hot encoding for class_label label_matrix[i, j, class_label] = 1