Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.miscellaneous import save_labels
from maskrcnn_benchmark.data.datasets.wrapper import WrapperDataset
from maskrcnn_benchmark.data.datasets.abstract import AbstractDataset

from . import datasets as D
from . import samplers
Expand Down Expand Up @@ -45,6 +47,61 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
dataset = factory(**args)
datasets.append(dataset)

log_info = ""
for dataset_idx, (dataset_name, dataset) in enumerate(zip(dataset_list, datasets)):
log_info += (
f"{dataset_idx:>2}: {dataset_name:>35} [{dataset.__class__.__name__}]\n"
)
logger = logging.getLogger(__name__)
logger.info("Dataset(s) provided in the config:\n" + log_info)

# BOTCS: Wrapping multiple datasets to mimick the first
if len(datasets) > 1:
logger.info(
"Multiple datasets were provided for training. "
"Dataset builder is wrapping datasets[1:] to mimick the first "
"dataset's category indexes. Index matching is based on matching "
"category names (str). Only works with derived Classes of "
"AbstractDataset. Otherwise no wrapping will be carried out. "
)
mimicked_dataset = datasets[0]
if not isinstance(mimicked_dataset, AbstractDataset):
logger.warning(
"ATTENTION! "
f"[{mimicked_dataset.__class__.__name__}] is not a derived Class "
"of AbstractDataset. Further datasets will not be wrapped. "
"Matching class indices could not be assured. "
"Current setting *could* lead to colliding category indices. "
)
else:
wrapped_datasets = []
for d in datasets[1:]:
if isinstance(d, type(mimicked_dataset)):
logger.warning(
f"[{d.__class__.__name__}] will not be wrapped, because "
"it matches the mimicked Dataset's type."
)
wrapped_datasets.append(d)
elif isinstance(d, AbstractDataset):
logger.warning(
f"Wrapping [{d.__class__.__name__}] to mimick "
f"[{mimicked_dataset.__class__.__name__}]"
)
wd = WrapperDataset(mimicked_dataset, d)
logger.info(str(wd))
wrapped_datasets.append(wd)
else:
logger.warning(
"ATTENTION! "
"Matching class indices could not be assured. "
f"Using [{mimicked_dataset.__class__.__name__}] and "
f"[{d.__class__.__name__}] *could* lead to colliding "
" category indices. "
)
wrapped_datasets.append(d)

datasets = [mimicked_dataset] + wrapped_datasets

# for testing, return a list of datasets
if not is_train:
return datasets
Expand Down Expand Up @@ -112,7 +169,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
assert (
images_per_batch % num_gpus == 0
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format(
images_per_batch, num_gpus)
images_per_batch, num_gpus
)
images_per_gpu = images_per_batch // num_gpus
shuffle = True
num_iters = cfg.SOLVER.MAX_ITER
Expand All @@ -121,7 +179,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
assert (
images_per_batch % num_gpus == 0
), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format(
images_per_batch, num_gpus)
images_per_batch, num_gpus
)
images_per_gpu = images_per_batch // num_gpus
shuffle = False if not is_distributed else True
num_iters = None
Expand Down Expand Up @@ -152,7 +211,11 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

# If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
transforms = (
None
if not is_train and cfg.TEST.BBOX_AUG.ENABLED
else build_transforms(cfg, is_train)
)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

if is_train:
Expand All @@ -165,8 +228,11 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
batch_sampler = make_batch_data_sampler(
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter
)
collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
collator = (
BBoxAugCollator()
if not is_train and cfg.TEST.BBOX_AUG.ENABLED
else BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
)
num_workers = cfg.DATALOADER.NUM_WORKERS
data_loader = torch.utils.data.DataLoader(
dataset,
Expand Down
97 changes: 97 additions & 0 deletions maskrcnn_benchmark/data/datasets/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch
from .abstract import AbstractDataset


class WrapperDataset(AbstractDataset):
    """
    Align one dataset's class indices with another's by matching class names.

    When training on multiple datasets the integer class labels can be
    misaligned, and our last hope is to look at the human readable class
    names to find a mapping.

    This auxiliary dataset finds the common class names between
    ``mimicked_dataset`` (A) and ``wrapped_dataset`` (B) and wraps the
    latter so that the ground truth it returns uses the former's category
    indices.

    IMPORTANT:
        By design this wrapper relies on fields and methods of datasets
        derived from AbstractDataset (``CLASSES``, ``name_to_id``,
        ``id_to_name``, ``get_img_info``).
    """

    def __init__(self, A, B):
        """
        Arguments:
            A: mimicked dataset -- its category indexing is preserved.
            B: wrapped dataset -- its ground truth is remapped onto A's ids.

        Raises:
            ValueError: if A and B share no class names, or if A's class
                list does not start with "__background__".
        """
        self.A = A
        self.B = B

        common_classes = set(A.CLASSES) & set(B.CLASSES)
        self.common_classes = common_classes
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not common_classes:
            raise ValueError(
                "WrapperDataset: no common class names between "
                f"[{A.__class__.__name__}] and [{B.__class__.__name__}]"
            )

        # Translation tables in both directions; None marks a class that has
        # no counterpart on the other side.
        self.idA_to_idB = {
            cid: B.name_to_id.get(name) for cid, name in A.id_to_name.items()
        }
        self.idB_to_idA = {
            cid: A.name_to_id.get(name) for cid, name in B.id_to_name.items()
        }

        # NOTE: By default ids go from 0 to N-1 to address all heads in the
        # RCNN RoI heads (contiguous id), and here we assume that the network
        # uses the `mimicked_dataset`'s classes, and all ids of the wrapper
        # will represent the corresponding class in the `mimicked_dataset`'s
        # indexing. Therefore by looking only at the wrapper's used IDs they
        # may not appear contiguous, still they are part of a contiguous
        # mapping.

        # Resolve a contiguous mapping by keeping A's full class list and
        # tagging the names B cannot provide.
        self.CLASSES = [
            name if name in common_classes else f"__unmatched__({name})"
            for name in A.CLASSES
        ]
        if self.CLASSES[0] != "__background__":
            raise ValueError(
                "WrapperDataset: mimicked dataset's class 0 must be "
                "'__background__'"
            )

        self.name_to_id = {name: cid for cid, name in enumerate(self.CLASSES)}
        self.id_to_name = {cid: name for name, cid in self.name_to_id.items()}

    def __getitem__(self, idx):
        """Return (img, target, idx) with labels remapped to A's indexing."""
        img, target, idx = self.B[idx]
        labels = target.get_field("labels")

        # Drop objects whose class has no counterpart in the mimicked
        # dataset. torch.bool instead of torch.uint8: uint8 mask indexing
        # is deprecated (and removed in recent PyTorch versions).
        keep = torch.tensor(
            [self.idB_to_idA[idB] is not None for idB in labels.tolist()],
            dtype=torch.bool,
        )

        # Boolean-mask indexing of the surviving annotations.
        labels = labels[keep]
        target.bbox = target.bbox[keep]

        if "masks" in target.fields():
            target.add_field("masks", target.get_field("masks")[keep])

        # Translate the remaining labels from B's indexing to A's.
        for i in range(len(labels)):
            labels[i] = self.idB_to_idA[labels[i].item()]
        target.add_field("labels", labels)
        return img, target, idx

    def get_img_info(self, idx):
        """Delegate image metadata lookup to the wrapped dataset."""
        return self.B.get_img_info(idx)

    def __len__(self):
        return len(self.B)

    def __str__(self):
        r = (
            f"[WrapperDataset mimicks:{self.A.__class__.__name__} "
            f"wraps:{self.B.__class__.__name__}]"
            f"\n{'Mimicked index':>15} : {'Mimicked label':<15} {'Wrapped index':>15} : {'Wrapped label':<15}\n"
        )
        for cid, name in self.A.id_to_name.items():
            r += f"{cid:>15} : {name:<15} -> {str(self.idA_to_idB[cid]):>15} : {self.id_to_name[cid]:<15}\n"
        return r