diff --git a/configs/e2e_mask_rcnn_R_50_FPN_1x_periodically_testing.yaml b/configs/e2e_mask_rcnn_R_50_FPN_1x_periodically_testing.yaml new file mode 100644 index 000000000..03ddc5058 --- /dev/null +++ b/configs/e2e_mask_rcnn_R_50_FPN_1x_periodically_testing.yaml @@ -0,0 +1,42 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" + BACKBONE: + CONV_BODY: "R-50-FPN" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 2000 + PRE_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + ROI_HEADS: + USE_FPN: True + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + POOLER_SAMPLING_RATIO: 2 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + ROI_MASK_HEAD: + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" + PREDICTOR: "MaskRCNNC4Predictor" + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + RESOLUTION: 28 + SHARE_BOX_FEATURE_EXTRACTOR: False + MASK_ON: True +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.02 + WEIGHT_DECAY: 0.0001 + STEPS: (60000, 80000) + MAX_ITER: 90000 + TEST_PERIOD: 2500 diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index beae4070a..260b48474 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -406,6 +406,7 @@ _C.SOLVER.WARMUP_METHOD = "linear" _C.SOLVER.CHECKPOINT_PERIOD = 2500 +_C.SOLVER.TEST_PERIOD = 0 # Number of images per batch # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py index b0ce3c348..07291019f 100644 --- a/maskrcnn_benchmark/data/build.py +++ b/maskrcnn_benchmark/data/build.py @@ -104,7 +104,7 @@ def make_batch_data_sampler( return batch_sampler -def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): +def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0, is_for_period=False): num_gpus = get_world_size() if is_train: images_per_batch = cfg.SOLVER.IMS_PER_BATCH @@ -152,7 +152,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): # If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train) - datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train) + datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train or is_for_period) data_loaders = [] for dataset in datasets: @@ -170,7 +170,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): collate_fn=collator, ) data_loaders.append(data_loader) - if is_train: + if is_train or is_for_period: # during training, a single (possibly concatenated) data_loader is returned assert len(data_loaders) == 1 return data_loaders[0] diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py index 281d91339..7870e1a28 100644 --- a/maskrcnn_benchmark/engine/trainer.py +++ b/maskrcnn_benchmark/engine/trainer.py @@ -1,13 +1,17 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import datetime import logging +import os import time import torch import torch.distributed as dist +from tqdm import tqdm -from maskrcnn_benchmark.utils.comm import get_world_size +from maskrcnn_benchmark.data import make_data_loader +from maskrcnn_benchmark.utils.comm import get_world_size, synchronize from maskrcnn_benchmark.utils.metric_logger import MetricLogger +from maskrcnn_benchmark.engine.inference import inference from apex import amp @@ -37,13 +41,16 @@ def reduce_loss_dict(loss_dict): def do_train( + cfg, model, data_loader, + data_loader_val, optimizer, scheduler, checkpointer, device, checkpoint_period, + test_period, arguments, ): logger = logging.getLogger("maskrcnn_benchmark.trainer") @@ -54,6 +61,14 @@ def do_train( model.train() start_training_time = time.time() end = time.time() + + iou_types = ("bbox",) + if cfg.MODEL.MASK_ON: + iou_types = iou_types + ("segm",) + if cfg.MODEL.KEYPOINT_ON: + iou_types = iou_types + ("keypoints",) + dataset_names = cfg.DATASETS.TEST + for iteration, (images, targets, _) in enumerate(data_loader, start_iter): data_time = time.time() - end iteration = iteration + 1 @@ -107,6 +122,53 @@ def do_train( ) if iteration % checkpoint_period == 0: checkpointer.save("model_{:07d}".format(iteration), **arguments) + if data_loader_val is not None and test_period > 0 and iteration % test_period == 0: + meters_val = MetricLogger(delimiter=" ") + synchronize() + _ = inference( # The result can be used for additional logging, e. g. for TensorBoard + model, + # The method changes the segmentation mask format in a data loader, + # so every time a new data loader is created: + make_data_loader(cfg, is_train=False, is_distributed=(get_world_size() > 1), is_for_period=True), + dataset_name="[Validation]", + iou_types=iou_types, + box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=None, + ) + synchronize() + model.train() + with torch.no_grad(): + # Should be one image for each GPU: + for iteration_val, (images_val, targets_val, _) in enumerate(tqdm(data_loader_val)): + images_val = images_val.to(device) + targets_val = [target.to(device) for target in targets_val] + loss_dict = model(images_val, targets_val) + losses = sum(loss for loss in loss_dict.values()) + loss_dict_reduced = reduce_loss_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters_val.update(loss=losses_reduced, **loss_dict_reduced) + synchronize() + logger.info( + meters_val.delimiter.join( + [ + "[Validation]: ", + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters_val), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) if iteration == max_iter: checkpointer.save("model_final", **arguments) diff --git a/tools/train_net.py b/tools/train_net.py index 9f4761b3f..6b7f6222b 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -72,16 +72,25 @@ def train(cfg, local_rank, distributed): start_iter=arguments["iteration"], ) + test_period = cfg.SOLVER.TEST_PERIOD + if test_period > 0: + data_loader_val = make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True) + else: + data_loader_val = None + checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD do_train( + cfg, model, data_loader, + data_loader_val, optimizer, scheduler, checkpointer, device, checkpoint_period, + test_period, arguments, )