diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py index 26239155d..d712e10f0 100644 --- a/maskrcnn_benchmark/data/build.py +++ b/maskrcnn_benchmark/data/build.py @@ -7,6 +7,8 @@ from maskrcnn_benchmark.utils.comm import get_world_size from maskrcnn_benchmark.utils.imports import import_file from maskrcnn_benchmark.utils.miscellaneous import save_labels +from maskrcnn_benchmark.data.datasets.wrapper import WrapperDataset +from maskrcnn_benchmark.data.datasets.abstract import AbstractDataset from . import datasets as D from . import samplers @@ -45,6 +47,61 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True): dataset = factory(**args) datasets.append(dataset) + log_info = "" + for dataset_idx, (dataset_name, dataset) in enumerate(zip(dataset_list, datasets)): + log_info += ( + f"{dataset_idx:>2}: {dataset_name:>35} [{dataset.__class__.__name__}]\n" + ) + logger = logging.getLogger(__name__) + logger.info("Dataset(s) provided in the config:\n" + log_info) + + # BOTCS: Wrapping multiple datasets to mimic the first + if len(datasets) > 1: + logger.info( + "Multiple datasets were provided for training. " + "Dataset builder is wrapping datasets[1:] to mimic the first " + "dataset's category indexes. Index matching is based on matching " + "category names (str). Only works with derived Classes of " + "AbstractDataset. Otherwise no wrapping will be carried out. " + ) + mimicked_dataset = datasets[0] + if not isinstance(mimicked_dataset, AbstractDataset): + logger.warning( + "ATTENTION! " + f"[{mimicked_dataset.__class__.__name__}] is not a derived Class " + "of AbstractDataset. Further datasets will not be wrapped. " + "Matching class indices could not be assured. " + "Current setting *could* lead to colliding category indices. 
" + ) + else: + wrapped_datasets = [] + for d in datasets[1:]: + if isinstance(d, type(mimicked_dataset)): + logger.warning( + f"[{d.__class__.__name__}] will not be wrapped, because " + "it matches the mimicked Dataset's type." + ) + wrapped_datasets.append(d) + elif isinstance(d, AbstractDataset): + logger.warning( + f"Wrapping [{d.__class__.__name__}] to mimick " + f"[{mimicked_dataset.__class__.__name__}]" + ) + wd = WrapperDataset(mimicked_dataset, d) + logger.info(str(wd)) + wrapped_datasets.append(wd) + else: + logger.warning( + "ATTENTION! " + "Matching class indices could not be assured. " + f"Using [{mimicked_dataset.__class__.__name__}] and " + f"[{d.__class__.__name__}] *could* lead to colliding " + " category indices. " + ) + wrapped_datasets.append(d) + + datasets = [mimicked_dataset] + wrapped_datasets + # for testing, return a list of datasets if not is_train: return datasets @@ -112,7 +169,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): assert ( images_per_batch % num_gpus == 0 ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( - images_per_batch, num_gpus) + images_per_batch, num_gpus + ) images_per_gpu = images_per_batch // num_gpus shuffle = True num_iters = cfg.SOLVER.MAX_ITER @@ -121,7 +179,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): assert ( images_per_batch % num_gpus == 0 ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( - images_per_batch, num_gpus) + images_per_batch, num_gpus + ) images_per_gpu = images_per_batch // num_gpus shuffle = False if not is_distributed else True num_iters = None @@ -152,7 +211,11 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST # If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later - transforms = None if not is_train 
and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train) + transforms = ( + None + if not is_train and cfg.TEST.BBOX_AUG.ENABLED + else build_transforms(cfg, is_train) + ) datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train) if is_train: @@ -165,8 +228,11 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): batch_sampler = make_batch_data_sampler( dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter ) - collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \ - BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY) + collator = ( + BBoxAugCollator() + if not is_train and cfg.TEST.BBOX_AUG.ENABLED + else BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY) + ) num_workers = cfg.DATALOADER.NUM_WORKERS data_loader = torch.utils.data.DataLoader( dataset, diff --git a/maskrcnn_benchmark/data/datasets/wrapper.py b/maskrcnn_benchmark/data/datasets/wrapper.py new file mode 100644 index 000000000..0cd54be50 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/wrapper.py @@ -0,0 +1,97 @@ +import torch +from .abstract import AbstractDataset + + +class WrapperDataset(AbstractDataset): + """ + When training on multiple datasets class labels can be misaligned, and our + last hope is to look at human readable class names to find a mapping + + This auxiliary dataset helps to find common class names between + `mimicked_dataset` and `wrapped_dataset` and wrap the latter to return + ground truth that aligns with the indices of the former. 
+ + A: mimicked_dataset + B: wrapped_dataset + + IMPORTANT: + By design this wrapper utilizes fields and methods of datasets + derived from AbstractDataset + """ + + def __init__(self, A, B): + # A: mimicked_dataset + # B: wrapped_dataset + self.A = A + self.B = B + + common_classes = set(A.CLASSES) & set(B.CLASSES) + self.common_classes = common_classes + assert len(common_classes) > 0 + + self.idA_to_idB = { + id: B.name_to_id[name] if name in B.name_to_id else None + for id, name in A.id_to_name.items() + } + self.idB_to_idA = { + id: A.name_to_id[name] if name in A.name_to_id else None + for id, name in B.id_to_name.items() + } + + # NOTE: By default ids go from 0 to N-1 to address all heads in the + # RCNN RoI heads (contiguous id), and here we assume that the network + # uses the `mimicked_dataset`'s classes, and all ids of the wrapper + # will represent the corresponding class in the `mimicked_dataset`'s + # indexing. Therefore by looking only at the wrapper's used IDs they may + # not appear contiguous, still they are part of a contiguous mapping. 
+ + # Resolving contiguous mapping by filling empty spots + self.CLASSES = [ + name if name in common_classes else f"__unmatched__({name})" + for name in A.CLASSES + ] + assert self.CLASSES[0] == "__background__" + + self.name_to_id = {name: id for id, name in enumerate(self.CLASSES)} + self.id_to_name = {id: name for name, id in self.name_to_id.items()} + + def __getitem__(self, idx): + img, target, idx = self.B[idx] + labels = target.get_field("labels") + + # Remove objects from wrapped GT belonging to classes not present in + # the mimicked dataset + select_idx = torch.tensor( + [self.idB_to_idA[idB] is not None for idB in labels.tolist()], + dtype=torch.uint8, + ) + + # Fancy indexing using a boolean selection tensor + labels = labels[select_idx] + target.bbox = target.bbox[select_idx] + + if "masks" in target.fields(): + masks = target.get_field("masks")[select_idx] + target.add_field("masks", masks) + + # Convert ids from wrapped to mimicked + for i in range(len(labels)): + labels[i] = self.idB_to_idA[labels[i].item()] + target.add_field("labels", labels) + return img, target, idx + + def get_img_info(self, idx): + return self.B.get_img_info(idx) + + def __len__(self): + return len(self.B) + + def __str__(self): + r = ( + f"[WrapperDataset mimics:{self.A.__class__.__name__} " + f"wraps:{self.B.__class__.__name__}]" + f"\n{'Mimicked index':>15} : {'Mimicked label':<15} {'Wrapped index':>15} : {'Wrapped label':<15}\n" + ) + for id, name in self.A.id_to_name.items(): + r += f"{id:>15} : {name:<15} -> {str(self.idA_to_idB[id]):>15} : {self.id_to_name[id]:<15}\n" + return r