diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py index 89b80aff2..98d9b3cce 100644 --- a/maskrcnn_benchmark/structures/segmentation_mask.py +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -454,7 +454,8 @@ def __getitem__(self, item): else: # advanced indexing on a single dimension selected_polygons = [] - if isinstance(item, torch.Tensor) and item.dtype == torch.uint8: + if isinstance(item, torch.Tensor) and \ + (item.dtype == torch.uint8 or item.dtype == torch.bool): item = item.nonzero() item = item.squeeze(1) if item.numel() > 0 else item item = item.tolist()