diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py
index fc750fd4f..4058deb89 100644
--- a/maskrcnn_benchmark/config/defaults.py
+++ b/maskrcnn_benchmark/config/defaults.py
@@ -395,6 +395,8 @@ _C.SOLVER.WARMUP_METHOD = "linear"
 
 _C.SOLVER.CHECKPOINT_PERIOD = 2500
+# Validate every 2500 iterations
+_C.SOLVER.TEST_PERIOD = 2500
 
 # Number of images per batch
 # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
 # see 2 images per batch
diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py
index d2895fd7e..cf0c8c2c2 100644
--- a/maskrcnn_benchmark/data/build.py
+++ b/maskrcnn_benchmark/data/build.py
@@ -104,7 +104,9 @@ def make_batch_data_sampler(
     return batch_sampler
 
 
-def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
+def make_data_loader(
+    cfg, is_train=True, is_distributed=False, start_iter=0, is_for_period=False
+):
     num_gpus = get_world_size()
     if is_train:
         images_per_batch = cfg.SOLVER.IMS_PER_BATCH
@@ -151,7 +153,9 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
     dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
 
     transforms = build_transforms(cfg, is_train)
-    datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
+    datasets = build_dataset(
+        dataset_list, transforms, DatasetCatalog, is_train or is_for_period
+    )
 
     data_loaders = []
     for dataset in datasets:
@@ -168,7 +172,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
             collate_fn=collator,
         )
         data_loaders.append(data_loader)
-    if is_train:
+    if is_train or is_for_period:
         # during training, a single (possibly concatenated) data_loader is returned
         assert len(data_loaders) == 1
         return data_loaders[0]
diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py
index 38a9e524b..66bcd13c5 100644
--- a/maskrcnn_benchmark/engine/trainer.py
+++ b/maskrcnn_benchmark/engine/trainer.py
@@ -1,13 +1,19 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import datetime
 import logging
+import os
 import time
 
 import torch
 import torch.distributed as dist
+from tqdm import tqdm
 
-from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.data import make_data_loader
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+from maskrcnn_benchmark.engine.inference import inference
+
+from apex import amp
 
 
 def reduce_loss_dict(loss_dict):
@@ -36,13 +42,16 @@ def reduce_loss_dict(loss_dict):
 
 
 def do_train(
+    cfg,
     model,
     data_loader,
+    data_loader_val,
     optimizer,
     scheduler,
     checkpointer,
     device,
     checkpoint_period,
+    test_period,
     arguments,
 ):
     logger = logging.getLogger("maskrcnn_benchmark.trainer")
@@ -53,6 +62,14 @@ def do_train(
     model.train()
     start_training_time = time.time()
     end = time.time()
+
+    iou_types = ("bbox",)
+    if cfg.MODEL.MASK_ON:
+        iou_types = iou_types + ("segm",)
+    if cfg.MODEL.KEYPOINT_ON:
+        iou_types = iou_types + ("keypoints",)
+    dataset_names = cfg.DATASETS.TEST
+
     for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
         data_time = time.time() - end
         iteration = iteration + 1
@@ -74,6 +91,10 @@ def do_train(
 
         optimizer.zero_grad()
         losses.backward()
+        # Note: If mixed precision is not used, this ends up doing nothing
+        # Otherwise apply loss scaling for mixed-precision recipe
+        # with amp.scale_loss(losses, optimizer) as scaled_losses:
+        #     scaled_losses.backward()
         optimizer.step()
 
         batch_time = time.time() - end
@@ -103,6 +124,66 @@ def do_train(
             )
         if iteration % checkpoint_period == 0:
             checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if (
+            data_loader_val is not None
+            and test_period > 0
+            and iteration % test_period == 0
+        ):
+            meters_val = MetricLogger(delimiter="  ")
+            synchronize()
+            output_folder = os.path.join(cfg.OUTPUT_DIR, "validation", str(iteration))
+            os.makedirs(output_folder, exist_ok=True)
+            _ = inference(  # The result can be used for additional logging, e.g. for TensorBoard
+                model,
+                # The method changes the segmentation mask format in the data loader,
+                # so a new data loader is created every time:
+                make_data_loader(
+                    cfg,
+                    is_train=False,
+                    is_distributed=(get_world_size() > 1),
+                    is_for_period=True,
+                ),
+                dataset_name="[Validation]",
+                iou_types=iou_types,
+                box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
+                device=cfg.MODEL.DEVICE,
+                expected_results=cfg.TEST.EXPECTED_RESULTS,
+                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
+                output_folder=output_folder,
+            )
+            synchronize()
+            model.train()  # inference() switched the model to eval mode; switch back
+            # with torch.no_grad():
+            #     # Should be one image for each GPU:
+            #     for iteration_val, (images_val, targets_val, _) in enumerate(
+            #         tqdm(data_loader_val)
+            #     ):
+            #         images_val = images_val.to(device)
+            #         targets_val = [target.to(device) for target in targets_val]
+            #         loss_dict = model(images_val, targets_val)
+            #         losses = sum(loss for loss in loss_dict.values())
+            #         loss_dict_reduced = reduce_loss_dict(loss_dict)
+            #         losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+            #         meters_val.update(loss=losses_reduced, **loss_dict_reduced)
+            #     synchronize()
+            # logger.info(
+            #     meters_val.delimiter.join(
+            #         [
+            #             "[Validation]: ",
+            #             "eta: {eta}",
+            #             "iter: {iter}",
+            #             "{meters}",
+            #             "lr: {lr:.6f}",
+            #             "max mem: {memory:.0f}",
+            #         ]
+            #     ).format(
+            #         eta=eta_string,
+            #         iter=iteration,
+            #         meters=str(meters_val),
+            #         lr=optimizer.param_groups[0]["lr"],
+            #         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+            #     )
+            # )
         if iteration == max_iter:
             checkpointer.save("model_final", **arguments)
 
diff --git a/tools/train_net.py b/tools/train_net.py
index e4f95f015..9f951198e 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -36,7 +36,9 @@ def train(cfg, local_rank, distributed):
 
     if distributed:
         model = torch.nn.parallel.DistributedDataParallel(
-            model, device_ids=[local_rank], output_device=local_rank,
+            model,
+            device_ids=[local_rank],
+            output_device=local_rank,
             # this should be removed if we update BatchNorm stats
             broadcast_buffers=False,
         )
@@ -60,16 +62,27 @@ def train(cfg, local_rank, distributed):
         start_iter=arguments["iteration"],
     )
 
+    test_period = cfg.SOLVER.TEST_PERIOD
+    if test_period > 0:
+        data_loader_val = make_data_loader(
+            cfg, is_train=False, is_distributed=distributed, is_for_period=True
+        )
+    else:
+        data_loader_val = None
+
     checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
     do_train(
+        cfg,
         model,
         data_loader,
+        data_loader_val,
         optimizer,
         scheduler,
         checkpointer,
         device,
         checkpoint_period,
+        test_period,
         arguments,
     )
 
     return model
@@ -93,7 +106,9 @@ def run_test(cfg, model, distributed):
             mkdir(output_folder)
             output_folders[idx] = output_folder
     data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
-    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
+    for output_folder, dataset_name, data_loader_val in zip(
+        output_folders, dataset_names, data_loaders_val
+    ):
         inference(
             model,
             data_loader_val,
@@ -138,9 +153,7 @@ def main():
 
     if args.distributed:
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(
-            backend="nccl", init_method="env://"
-        )
+        torch.distributed.init_process_group(backend="nccl", init_method="env://")
         synchronize()
 
     cfg.merge_from_file(args.config_file)
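
Summary of the change: a new `SOLVER.TEST_PERIOD` option is threaded from `defaults.py` through `tools/train_net.py` into `do_train`, which now runs `inference` on `cfg.DATASETS.TEST` every `test_period` iterations and writes the results under `OUTPUT_DIR/validation/<iteration>`. Because `train_net.py` only builds `data_loader_val` when `test_period > 0`, periodic validation can be disabled by setting `SOLVER.TEST_PERIOD` to 0 on the command line or in a config file. Below is a minimal, self-contained sketch of the loop shape the `trainer.py` hunk introduces; the linear model, random data, and MSE metric are placeholders, not maskrcnn_benchmark APIs, and only the checkpoint/validation periodicity mirrors the patch:

```python
# Toy sketch of the checkpoint-and-validate loop structure added above.
# Everything here is a stand-in chosen so the example runs on its own.
import torch
import torch.nn.functional as F
from torch import nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
val_x, val_y = torch.randn(32, 8), torch.randn(32, 1)

checkpoint_period, test_period, max_iter = 50, 25, 100
model.train()
for iteration in range(1, max_iter + 1):
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    optimizer.zero_grad()
    F.mse_loss(model(x), y).backward()
    optimizer.step()

    if iteration % checkpoint_period == 0:
        torch.save(model.state_dict(), "model_{:07d}.pth".format(iteration))
    if test_period > 0 and iteration % test_period == 0:
        model.eval()  # the patch's inference() does this switch internally
        with torch.no_grad():
            val_loss = F.mse_loss(model(val_x), val_y).item()
        print("[Validation] iter {}: loss {:.4f}".format(iteration, val_loss))
        model.train()  # resume train mode before the next training step
```

One detail worth keeping in mind: `inference` puts the model into eval mode internally, so `do_train` has to call `model.train()` again after the validation block; otherwise subsequent iterations would run with the model in eval mode and return detections instead of the loss dict the training loop expects.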