diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index 367337f0d..b7bfffdea 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -62,7 +62,16 @@ # Flips _C.INPUT.HORIZONTAL_FLIP_PROB_TRAIN = 0.5 -_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.0 +_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.5 + +# Rotate +_C.INPUT.ANGLE_TRAIN = 10 + +# Linear Light Shade +_C.INPUT.VERTICAL_LIGHT_PROB_TRAIN = 0.5 +_C.INPUT.VERTICAL_LIGHT_SCALE_TRAIN = 20 +_C.INPUT.HORIZONTAL_LIGHT_PROB_TRAIN = 0.5 +_C.INPUT.HORIZONTAL_LIGHT_SCALE_TRAIN = 20 # ----------------------------------------------------------------------------- # Dataset diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py index 098da11d4..e297ca501 100644 --- a/maskrcnn_benchmark/data/transforms/build.py +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -12,6 +12,11 @@ def build_transforms(cfg, is_train=True): contrast = cfg.INPUT.CONTRAST saturation = cfg.INPUT.SATURATION hue = cfg.INPUT.HUE + angle = cfg.INPUT.ANGLE_TRAIN + light_vertical_prob = cfg.INPUT.VERTICAL_LIGHT_PROB_TRAIN + light_vertical_scale = cfg.INPUT.VERTICAL_LIGHT_SCALE_TRAIN + light_horizontal_prob = cfg.INPUT.HORIZONTAL_LIGHT_PROB_TRAIN + light_horizontal_scale = cfg.INPUT.HORIZONTAL_LIGHT_SCALE_TRAIN else: min_size = cfg.INPUT.MIN_SIZE_TEST max_size = cfg.INPUT.MAX_SIZE_TEST @@ -21,6 +26,11 @@ def build_transforms(cfg, is_train=True): contrast = 0.0 saturation = 0.0 hue = 0.0 + angle = 0.0 + light_vertical_prob = 0.0 + light_vertical_scale = 0.0 + light_horizontal_prob = 0.0 + light_horizontal_scale = 0.0 to_bgr255 = cfg.INPUT.TO_BGR255 normalize_transform = T.Normalize( @@ -35,7 +45,10 @@ def build_transforms(cfg, is_train=True): transform = T.Compose( [ + T.SmallAngleRotate(angle), color_jitter, + T.HorizontalLinearLight(light_horizontal_prob, light_horizontal_scale), + T.VerticalLinearLight(light_vertical_prob, 
light_vertical_scale), T.Resize(min_size, max_size), T.RandomHorizontalFlip(flip_horizontal_prob), T.RandomVerticalFlip(flip_vertical_prob), diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py index 2d37dc72f..2aefa7bde 100644 --- a/maskrcnn_benchmark/data/transforms/transforms.py +++ b/maskrcnn_benchmark/data/transforms/transforms.py @@ -1,6 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import random +from PIL import Image +import numpy as np +import cv2 import torch import torchvision from torchvision.transforms import functional as F @@ -83,6 +86,59 @@ def __call__(self, image, target): target = target.transpose(1) return image, target +class HorizontalLinearLight(object): + def __init__(self, prob=0.5, lightsacle=50): + self.prob = prob + self.lightsacle = lightsacle + + def __call__(self, image, target): + if random.random() < self.prob: + image = np.asarray(image.copy()) + h, w, c = image.shape + x = np.linspace(-1*self.lightsacle, self.lightsacle, w) + weight = np.expand_dims(x, axis=1) + weight = weight.repeat(c, axis=1) + image = image + weight + image[image < 0] = 0 + image[image > 255] = 255 + image = Image.fromarray(image.astype(np.uint8)) + return image, target + +class VerticalLinearLight(object): + def __init__(self, prob=0.5, lightsacle=50): + self.prob = prob + self.lightsacle = lightsacle + + def __call__(self, image, target): + if random.random() < self.prob: + image = np.asarray(image.copy()) + h, w, c = image.shape + x = np.linspace(-1*self.lightsacle, self.lightsacle, h) + weight = np.expand_dims(x, axis=1) + weight = weight.repeat(c, axis=1) + weight = np.expand_dims(weight, axis=1) + image = image + weight + image[image < 0] = 0 + image[image > 255] = 255 + image = Image.fromarray(image.astype(np.uint8)) + return image, target + +class SmallAngleRotate(object): + def __init__(self,angle=10): + self.angle_range = angle + + def __call__(self, image, 
target): + self.angle = random.randint(-1*self.angle_range, self.angle_range) + image = np.asarray(image.copy()) + h, w, _ = image.shape + cx = w / 2 + cy = h / 2 + M = cv2.getRotationMatrix2D((cx, cy), self.angle, 1.0) + image = cv2.warpAffine(image, M, (w, h)) + target = target.rotate(M) + image = Image.fromarray(image) + return image, target + class ColorJitter(object): def __init__(self, brightness=None, @@ -119,3 +175,4 @@ def __call__(self, image, target=None): if target is None: return image return image, target + diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py index 25791d578..688571e1a 100644 --- a/maskrcnn_benchmark/structures/bounding_box.py +++ b/maskrcnn_benchmark/structures/bounding_box.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch +import numpy as np # transpose FLIP_LEFT_RIGHT = 0 @@ -25,7 +26,7 @@ def __init__(self, bbox, image_size, mode="xyxy"): ) if bbox.size(-1) != 4: raise ValueError( - "last dimension of bbox should have a " + "last dimension of bbox should have a " "size of 4, got {}".format(bbox.size(-1)) ) if mode not in ("xyxy", "xywh"): @@ -164,9 +165,31 @@ def transpose(self, method): bbox.add_field(k, v) return bbox.convert(self.mode) + def rotate(self, matrix): + """ + Returns a rotated copy of this bounding box + + :param matrix: Rotation matrix calculated + according to certain degree by Opencv CV2::getRotationMatrix2D + """ + bbox = BoxList(self.bbox, self.size, mode="xyxy") + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.rotate(matrix) + newbox = [] + for polygon in v.polygons: + coods = polygon.polygons[0].numpy() + allx = coods[[0, 2, 4, 6]] + ally = coods[[1, 3, 5, 7]] + newbox.append([allx.min(), ally.min(), allx.max(), ally.max()]) + bbox.bbox = torch.from_numpy(np.asarray(newbox)) + bbox.add_field(k, v) + return bbox.convert(self.mode) + 
def crop(self, box): """ - Crops a rectangular region from this bounding box. The box is a + Crops a rectangular region from this bounding box. The box is a 4-tuple defining the left, upper, right, and lower pixel coordinate. """ @@ -232,18 +255,18 @@ def area(self): area = box[:, 2] * box[:, 3] else: raise RuntimeError("Should not be here") - + return area def copy_with_fields(self, fields, skip_missing=False): bbox = BoxList(self.bbox, self.size, self.mode) if not isinstance(fields, (list, tuple)): fields = [fields] for field in fields: if self.has_field(field): bbox.add_field(field, self.get_field(field)) elif not skip_missing: raise KeyError("Field '{}' not found in {}".format(field, self)) return bbox def __repr__(self): diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py index 84ef9dbc9..9901fbdf6 100644 --- a/maskrcnn_benchmark/structures/segmentation_mask.py +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -1,9 +1,5 @@ -import cv2 -import copy +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch -import numpy as np -from maskrcnn_benchmark.layers.misc import interpolate -from maskrcnn_benchmark.utils import cv2_util import pycocotools.mask as mask_utils # transpose @@ -11,241 +7,63 @@ FLIP_TOP_BOTTOM = 1 -""" ABSTRACT -Segmentations come in either: -1) Binary masks -2) Polygons - -Binary masks can be represented in a contiguous array -and operations can be carried out more efficiently, -therefore BinaryMaskList handles them together. - -Polygons are handled separately for each instance, -by PolygonInstance and instances are handled by -PolygonList. - -SegmentationList is supposed to represent both, -therefore it wraps the functions of BinaryMaskList -and PolygonList to make it transparent. 
-""" - - -class BinaryMaskList(object): +class Mask(object): """ - This class handles binary masks for all objects in the image + This class is unfinished and not meant for use yet + It is supposed to contain the mask for an object as + a 2d tensor """ - def __init__(self, masks, size): - """ - Arguments: - masks: Either torch.tensor of [num_instances, H, W] - or list of torch.tensors of [H, W] with num_instances elems, - or RLE (Run Length Encoding) - interpreted as list of dicts, - or BinaryMaskList. - size: absolute image size, width first - - After initialization, a hard copy will be made, to leave the - initializing source data intact. - """ - - assert isinstance(size, (list, tuple)) - assert len(size) == 2 - - if isinstance(masks, torch.Tensor): - # The raw data representation is passed as argument - masks = masks.clone() - elif isinstance(masks, (list, tuple)): - if len(masks) == 0: - masks = torch.empty([0, size[1], size[0]]) # num_instances = 0! - elif isinstance(masks[0], torch.Tensor): - masks = torch.stack(masks, dim=0).clone() - elif isinstance(masks[0], dict) and "counts" in masks[0]: - if(isinstance(masks[0]["counts"], (list, tuple))): - masks = mask_utils.frPyObjects(masks, size[1], size[0]) - # RLE interpretation - rle_sizes = [tuple(inst["size"]) for inst in masks] - - masks = mask_utils.decode(masks) # [h, w, n] - masks = torch.tensor(masks).permute(2, 0, 1) # [n, h, w] - - assert rle_sizes.count(rle_sizes[0]) == len(rle_sizes), ( - "All the sizes must be the same size: %s" % rle_sizes - ) - - # in RLE, height come first in "size" - rle_height, rle_width = rle_sizes[0] - assert masks.shape[1] == rle_height - assert masks.shape[2] == rle_width - - width, height = size - if width != rle_width or height != rle_height: - masks = interpolate( - input=masks[None].float(), - size=(height, width), - mode="bilinear", - align_corners=False, - )[0].type_as(masks) - else: - RuntimeError( - "Type of `masks[0]` could not be interpreted: %s" - % type(masks) - ) 
- elif isinstance(masks, BinaryMaskList): - # just hard copy the BinaryMaskList instance's underlying data - masks = masks.masks.clone() - else: - RuntimeError( - "Type of `masks` argument could not be interpreted:%s" - % type(masks) - ) - - if len(masks.shape) == 2: - # if only a single instance mask is passed - masks = masks[None] - - assert len(masks.shape) == 3 - assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1]) - assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0]) - + def __init__(self, masks, size, mode): self.masks = masks - self.size = tuple(size) + self.size = size + self.mode = mode def transpose(self, method): - dim = 1 if method == FLIP_TOP_BOTTOM else 2 - flipped_masks = self.masks.flip(dim) - return BinaryMaskList(flipped_masks, self.size) - - def crop(self, box): - assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box)) - # box is assumed to be xyxy - current_width, current_height = self.size - xmin, ymin, xmax, ymax = [round(float(b)) for b in box] - - assert xmin <= xmax and ymin <= ymax, str(box) - xmin = min(max(xmin, 0), current_width - 1) - ymin = min(max(ymin, 0), current_height - 1) - - xmax = min(max(xmax, 0), current_width) - ymax = min(max(ymax, 0), current_height) - - xmax = max(xmax, xmin + 1) - ymax = max(ymax, ymin + 1) - - width, height = xmax - xmin, ymax - ymin - cropped_masks = self.masks[:, ymin:ymax, xmin:xmax] - cropped_size = width, height - return BinaryMaskList(cropped_masks, cropped_size) - - def resize(self, size): - try: - iter(size) - except TypeError: - assert isinstance(size, (int, float)) - size = size, size - width, height = map(int, size) - - assert width > 0 - assert height > 0 - - # Height comes first here! 
- resized_masks = interpolate( - input=self.masks[None].float(), - size=(height, width), - mode="bilinear", - align_corners=False, - )[0].type_as(self.masks) - resized_size = width, height - return BinaryMaskList(resized_masks, resized_size) - - def convert_to_polygon(self): - if self.masks.numel() == 0: - return PolygonList([], self.size) - - contours = self._findContours() - return PolygonList(contours, self.size) - - def to(self, *args, **kwargs): - return self - - def _findContours(self): - contours = [] - masks = self.masks.detach().numpy() - for mask in masks: - mask = cv2.UMat(mask) - contour, hierarchy = cv2_util.findContours( - mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1 + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" ) - reshaped_contour = [] - for entity in contour: - assert len(entity.shape) == 3 - assert ( - entity.shape[1] == 1 - ), "Hierarchical contours are not allowed" - reshaped_contour.append(entity.reshape(-1).tolist()) - contours.append(reshaped_contour) - return contours + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + idx = 2 + elif method == FLIP_TOP_BOTTOM: + dim = height + idx = 1 - def __len__(self): - return len(self.masks) + flip_idx = list(range(dim)[::-1]) + flipped_masks = self.masks.index_select(dim, flip_idx) + return Mask(flipped_masks, self.size, self.mode) - def __getitem__(self, index): - if self.masks.numel() == 0: - raise RuntimeError("Indexing empty BinaryMaskList") - return BinaryMaskList(self.masks[index], self.size) + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] - def __iter__(self): - return iter(self.masks) + cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]] + return Mask(cropped_masks, size=(w, h), mode=self.mode) - def __repr__(self): - s = self.__class__.__name__ + "(" - s += "num_instances={}, ".format(len(self.masks)) - s += "image_width={}, 
".format(self.size[0]) - s += "image_height={})".format(self.size[1]) - return s + def resize(self, size, *args, **kwargs): + pass -class PolygonInstance(object): +class Polygons(object): """ This class holds a set of polygons that represents a single instance of an object mask. The object can be represented as a set of polygons """ - def __init__(self, polygons, size): - """ - Arguments: - a list of lists of numbers. - The first level refers to all the polygons that compose the - object, and the second level to the polygon coordinates. - """ - if isinstance(polygons, (list, tuple)): - valid_polygons = [] - for p in polygons: - p = torch.as_tensor(p, dtype=torch.float32) - if len(p) >= 6: # 3 * 2 coordinates - valid_polygons.append(p) - polygons = valid_polygons - - elif isinstance(polygons, PolygonInstance): - polygons = copy.copy(polygons.polygons) - - else: - RuntimeError( - "Type of argument `polygons` is not allowed:%s" - % (type(polygons)) - ) - - """ This crashes the training way too many times... 
- for p in polygons: - assert p[::2].min() >= 0 - assert p[::2].max() < size[0] - assert p[1::2].min() >= 0 - assert p[1::2].max() , size[1] - """ + def __init__(self, polygons, size, mode): + # assert isinstance(polygons, list), '{}'.format(polygons) + if isinstance(polygons, list): + polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons] + elif isinstance(polygons, Polygons): + polygons = polygons.polygons self.polygons = polygons - self.size = tuple(size) + self.size = size + self.mode = mode def transpose(self, method): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): @@ -268,51 +86,47 @@ def transpose(self, method): p[idx::2] = dim - poly[idx::2] - TO_REMOVE flipped_polygons.append(p) - return PolygonInstance(flipped_polygons, size=self.size) - - def crop(self, box): - assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box)) - - # box is assumed to be xyxy - current_width, current_height = self.size - xmin, ymin, xmax, ymax = map(float, box) + return Polygons(flipped_polygons, size=self.size, mode=self.mode) - assert xmin <= xmax and ymin <= ymax, str(box) - xmin = min(max(xmin, 0), current_width - 1) - ymin = min(max(ymin, 0), current_height - 1) + def rotate(self, matrix): + a = torch.from_numpy(matrix[:, :2]).float() + b = torch.from_numpy(matrix[:, 2:]).float() + b = torch.reshape(b, shape=(2,)) + a = a.transpose(1,0) + rotated_polygons = [] + for poly in self.polygons: + p = poly.clone() + p = p.reshape([-1, 2]) + p = torch.mm(p, a) + b + p = p.reshape([-1]).clone().float() + p[0::2] = p[0::2].clamp(min=0, max=self.size[0]-1) + p[1::2] = p[1::2].clamp(min=0, max=self.size[1]-1) + rotated_polygons.append(p) + return Polygons(rotated_polygons, size=self.size, mode=self.mode) - xmax = min(max(xmax, 0), current_width) - ymax = min(max(ymax, 0), current_height) - xmax = max(xmax, xmin + 1) - ymax = max(ymax, ymin + 1) + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] - w, h = xmax - xmin, ymax - ymin + # TODO chck 
if necessary + w = max(w, 1) + h = max(h, 1) cropped_polygons = [] for poly in self.polygons: p = poly.clone() - p[0::2] = p[0::2] - xmin # .clamp(min=0, max=w) - p[1::2] = p[1::2] - ymin # .clamp(min=0, max=h) + p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) + p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) cropped_polygons.append(p) - return PolygonInstance(cropped_polygons, size=(w, h)) - - def resize(self, size): - try: - iter(size) - except TypeError: - assert isinstance(size, (int, float)) - size = size, size - - ratios = tuple( - float(s) / float(s_orig) for s, s_orig in zip(size, self.size) - ) + return Polygons(cropped_polygons, size=(w, h), mode=self.mode) + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) if ratios[0] == ratios[1]: ratio = ratios[0] scaled_polys = [p * ratio for p in self.polygons] - return PolygonInstance(scaled_polys, size) + return Polygons(scaled_polys, size, mode=self.mode) ratio_w, ratio_h = ratios scaled_polygons = [] @@ -322,85 +136,48 @@ def resize(self, size): p[1::2] *= ratio_h scaled_polygons.append(p) - return PolygonInstance(scaled_polygons, size=size) + return Polygons(scaled_polygons, size=size, mode=self.mode) - def convert_to_binarymask(self): - width, height = self.size - # formatting for COCO PythonAPI - polygons = [p.numpy() for p in self.polygons] - rles = mask_utils.frPyObjects(polygons, height, width) - rle = mask_utils.merge(rles) - mask = mask_utils.decode(rle) - mask = torch.from_numpy(mask) - return mask - def __len__(self): - return len(self.polygons) + def convert(self, mode): + width, height = self.size + if mode == "mask": + rles = mask_utils.frPyObjects( + [p.numpy() for p in self.polygons], height, width + ) + rle = mask_utils.merge(rles) + mask = mask_utils.decode(rle) + mask = torch.from_numpy(mask) + # TODO add squeeze? 
+ return mask def __repr__(self): s = self.__class__.__name__ + "(" - s += "num_groups={}, ".format(len(self.polygons)) + s += "num_polygons={}, ".format(len(self.polygons)) s += "image_width={}, ".format(self.size[0]) - s += "image_height={})".format(self.size[1]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) return s -class PolygonList(object): +class SegmentationMask(object): """ - This class handles PolygonInstances for all objects in the image + This class stores the segmentations for all objects in the image """ - def __init__(self, polygons, size): + def __init__(self, polygons, size, mode=None): """ Arguments: - polygons: - a list of list of lists of numbers. The first + polygons: a list of list of lists of numbers. The first level of the list correspond to individual instances, the second level to all the polygons that compose the object, and the third level to the polygon coordinates. - - OR - - a list of PolygonInstances. - - OR - - a PolygonList - - size: absolute image size - """ - if isinstance(polygons, (list, tuple)): - if len(polygons) == 0: - polygons = [[[]]] - if isinstance(polygons[0], (list, tuple)): - assert isinstance(polygons[0][0], (list, tuple)), str( - type(polygons[0][0]) - ) - else: - assert isinstance(polygons[0], PolygonInstance), str( - type(polygons[0]) - ) - - elif isinstance(polygons, PolygonList): - size = polygons.size - polygons = polygons.polygons - - else: - RuntimeError( - "Type of argument `polygons` is not allowed:%s" - % (type(polygons)) - ) - - assert isinstance(size, (list, tuple)), str(type(size)) + assert isinstance(polygons, list) - self.polygons = [] - for p in polygons: - p = PolygonInstance(p, size) - if len(p) > 0: - self.polygons.append(p) - - self.size = tuple(size) + self.polygons = [Polygons(p, size, mode) for p in polygons] + self.size = size + self.mode = mode def transpose(self, method): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): @@ -408,61 +185,48 @@ def 
transpose(self, method): "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" ) - flipped_polygons = [] + flipped = [] for polygon in self.polygons: - flipped_polygons.append(polygon.transpose(method)) + flipped.append(polygon.transpose(method)) + return SegmentationMask(flipped, size=self.size, mode=self.mode) - return PolygonList(flipped_polygons, size=self.size) + def rotate(self, matrix): + rotated = [] + for polygon in self.polygons: + r = polygon.rotate(matrix) + rotated.append(r) + return SegmentationMask(rotated, size=self.size, mode=self.mode) def crop(self, box): w, h = box[2] - box[0], box[3] - box[1] - cropped_polygons = [] + cropped = [] for polygon in self.polygons: - cropped_polygons.append(polygon.crop(box)) - - cropped_size = w, h - return PolygonList(cropped_polygons, cropped_size) + cropped.append(polygon.crop(box)) + return SegmentationMask(cropped, size=(w, h), mode=self.mode) - def resize(self, size): - resized_polygons = [] + def resize(self, size, *args, **kwargs): + scaled = [] for polygon in self.polygons: - resized_polygons.append(polygon.resize(size)) - - resized_size = size - return PolygonList(resized_polygons, resized_size) + scaled.append(polygon.resize(size, *args, **kwargs)) + return SegmentationMask(scaled, size=size, mode=self.mode) def to(self, *args, **kwargs): return self - def convert_to_binarymask(self): - if len(self) > 0: - masks = torch.stack( - [p.convert_to_binarymask() for p in self.polygons] - ) - else: - size = self.size - masks = torch.empty([0, size[1], size[0]], dtype=torch.bool) - - return BinaryMaskList(masks, size=self.size) - - def __len__(self): - return len(self.polygons) - def __getitem__(self, item): - if isinstance(item, int): + if isinstance(item, (int, slice)): selected_polygons = [self.polygons[item]] - elif isinstance(item, slice): - selected_polygons = self.polygons[item] else: # advanced indexing on a single dimension selected_polygons = [] - if isinstance(item, torch.Tensor) and item.dtype == 
torch.bool: + if isinstance(item, torch.Tensor) and \ + (item.dtype == torch.uint8 or item.dtype == torch.bool): item = item.nonzero() item = item.squeeze(1) if item.numel() > 0 else item item = item.tolist() for i in item: selected_polygons.append(self.polygons[i]) - return PolygonList(selected_polygons, size=self.size) + return SegmentationMask(selected_polygons, size=self.size, mode=self.mode) def __iter__(self): return iter(self.polygons) @@ -473,105 +237,3 @@ def __repr__(self): s += "image_width={}, ".format(self.size[0]) s += "image_height={})".format(self.size[1]) return s - - -class SegmentationMask(object): - - """ - This class stores the segmentations for all objects in the image. - It wraps BinaryMaskList and PolygonList conveniently. - """ - - def __init__(self, instances, size, mode="poly"): - """ - Arguments: - instances: two types - (1) polygon - (2) binary mask - size: (width, height) - mode: 'poly', 'mask'. if mode is 'mask', convert mask of any format to binary mask - """ - - assert isinstance(size, (list, tuple)) - assert len(size) == 2 - if isinstance(size[0], torch.Tensor): - assert isinstance(size[1], torch.Tensor) - size = size[0].item(), size[1].item() - - assert isinstance(size[0], (int, float)) - assert isinstance(size[1], (int, float)) - - if mode == "poly": - self.instances = PolygonList(instances, size) - elif mode == "mask": - self.instances = BinaryMaskList(instances, size) - else: - raise NotImplementedError("Unknown mode: %s" % str(mode)) - - self.mode = mode - self.size = tuple(size) - - def transpose(self, method): - flipped_instances = self.instances.transpose(method) - return SegmentationMask(flipped_instances, self.size, self.mode) - - def crop(self, box): - cropped_instances = self.instances.crop(box) - cropped_size = cropped_instances.size - return SegmentationMask(cropped_instances, cropped_size, self.mode) - - def resize(self, size, *args, **kwargs): - resized_instances = self.instances.resize(size) - resized_size = size - 
return SegmentationMask(resized_instances, resized_size, self.mode) - - def to(self, *args, **kwargs): - return self - - def convert(self, mode): - if mode == self.mode: - return self - - if mode == "poly": - converted_instances = self.instances.convert_to_polygon() - elif mode == "mask": - converted_instances = self.instances.convert_to_binarymask() - else: - raise NotImplementedError("Unknown mode: %s" % str(mode)) - - return SegmentationMask(converted_instances, self.size, mode) - - def get_mask_tensor(self): - instances = self.instances - if self.mode == "poly": - instances = instances.convert_to_binarymask() - # If there is only 1 instance - return instances.masks.squeeze(0) - - def __len__(self): - return len(self.instances) - - def __getitem__(self, item): - selected_instances = self.instances.__getitem__(item) - return SegmentationMask(selected_instances, self.size, self.mode) - - def __iter__(self): - self.iter_idx = 0 - return self - - def __next__(self): - if self.iter_idx < self.__len__(): - next_segmentation = self.__getitem__(self.iter_idx) - self.iter_idx += 1 - return next_segmentation - raise StopIteration() - - next = __next__ # Python 2 compatibility - - def __repr__(self): - s = self.__class__.__name__ + "(" - s += "num_instances={}, ".format(len(self.instances)) - s += "image_width={}, ".format(self.size[0]) - s += "image_height={}, ".format(self.size[1]) - s += "mode={})".format(self.mode) - return s