diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py index 25791d578..bb17cad8b 100644 --- a/maskrcnn_benchmark/structures/bounding_box.py +++ b/maskrcnn_benchmark/structures/bounding_box.py @@ -61,12 +61,12 @@ def convert(self, mode): # self.mode xmin, ymin, xmax, ymax = self._split_into_xyxy() if mode == "xyxy": - bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1) + bbox = torch.cat((xmin, ymin, xmax, ymax), dim=1) bbox = BoxList(bbox, self.size, mode=mode) else: TO_REMOVE = 1 bbox = torch.cat( - (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 + (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1 ) bbox = BoxList(bbox, self.size, mode=mode) bbox._copy_extra_fields(self) @@ -74,11 +74,11 @@ def convert(self, mode): def _split_into_xyxy(self): if self.mode == "xyxy": - xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1) + xmin, ymin, xmax, ymax = self.bbox.split(1, dim=1) return xmin, ymin, xmax, ymax elif self.mode == "xywh": TO_REMOVE = 1 - xmin, ymin, w, h = self.bbox.split(1, dim=-1) + xmin, ymin, w, h = self.bbox.split(1, dim=1) return ( xmin, ymin, @@ -115,7 +115,7 @@ def resize(self, size, *args, **kwargs): scaled_ymin = ymin * ratio_height scaled_ymax = ymax * ratio_height scaled_box = torch.cat( - (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 + (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=1 ) bbox = BoxList(scaled_box, size, mode="xyxy") # bbox._copy_extra_fields(self) @@ -154,7 +154,7 @@ def transpose(self, method): transposed_ymax = image_height - ymin transposed_boxes = torch.cat( - (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 + (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=1 ) bbox = BoxList(transposed_boxes, self.size, mode="xyxy") # bbox._copy_extra_fields(self) @@ -182,7 +182,7 @@ def crop(self, box): is_empty = (cropped_xmin == cropped_xmax) | (cropped_ymin == cropped_ymax) cropped_box = torch.cat( - (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1 + (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=1 ) bbox = BoxList(cropped_box, (w, h), mode="xyxy") # bbox._copy_extra_fields(self)