2 changes: 2 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
@@ -395,6 +395,8 @@
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.CHECKPOINT_PERIOD = 2500
# Run validation every 2500 iterations; 0 disables periodic validation
_C.SOLVER.TEST_PERIOD = 2500

# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
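Like CHECKPOINT_PERIOD, the new knob can be overridden per experiment. A minimal sketch, assuming the stock yacs cfg from maskrcnn_benchmark.config (the value 5000 is only an example); tools/train_net.py reaches the same effect through its trailing command-line opts:

from maskrcnn_benchmark.config import cfg

# yacs parses the string value, so SOLVER.TEST_PERIOD becomes an int
cfg.merge_from_list(["SOLVER.TEST_PERIOD", "5000"])
assert cfg.SOLVER.TEST_PERIOD == 5000  # 0 disables periodic validation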
10 changes: 7 additions & 3 deletions maskrcnn_benchmark/data/build.py
@@ -104,7 +104,9 @@ def make_batch_data_sampler(
return batch_sampler


def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
def make_data_loader(
cfg, is_train=True, is_distributed=False, start_iter=0, is_for_period=False
):
num_gpus = get_world_size()
if is_train:
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
@@ -151,7 +153,9 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

transforms = build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
datasets = build_dataset(
dataset_list, transforms, DatasetCatalog, is_train or is_for_period
)

data_loaders = []
for dataset in datasets:
@@ -168,7 +172,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
collate_fn=collator,
)
data_loaders.append(data_loader)
if is_train:
if is_train or is_for_period:
# during training, a single (possibly concatenated) data_loader is returned
assert len(data_loaders) == 1
return data_loaders[0]
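To make the new flag concrete, here is a sketch of the three ways make_data_loader can now be called, assuming a loaded cfg; the return types follow from the is_train or is_for_period branch above:

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader

# Training: one (possibly concatenated) loader over cfg.DATASETS.TRAIN.
train_loader = make_data_loader(cfg, is_train=True)

# Periodic validation: also a single loader, but over cfg.DATASETS.TEST;
# build_dataset treats it like a training set (in the stock implementation,
# images without annotations are filtered and datasets are concatenated).
val_loader = make_data_loader(cfg, is_train=False, is_for_period=True)

# Final evaluation: a list with one loader per entry in cfg.DATASETS.TEST.
test_loaders = make_data_loader(cfg, is_train=False)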
83 changes: 82 additions & 1 deletion maskrcnn_benchmark/engine/trainer.py
@@ -1,13 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import os
import time

import torch
import torch.distributed as dist
from tqdm import tqdm

from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.engine.inference import inference

from apex import amp


def reduce_loss_dict(loss_dict):
@@ -36,13 +42,16 @@ def reduce_loss_dict(loss_dict):


def do_train(
cfg,
model,
data_loader,
data_loader_val,
optimizer,
scheduler,
checkpointer,
device,
checkpoint_period,
test_period,
arguments,
):
logger = logging.getLogger("maskrcnn_benchmark.trainer")
@@ -53,6 +62,14 @@
model.train()
start_training_time = time.time()
end = time.time()

iou_types = ("bbox",)
if cfg.MODEL.MASK_ON:
iou_types = iou_types + ("segm",)
if cfg.MODEL.KEYPOINT_ON:
iou_types = iou_types + ("keypoints",)
dataset_names = cfg.DATASETS.TEST

for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
data_time = time.time() - end
iteration = iteration + 1
@@ -74,6 +91,10 @@

optimizer.zero_grad()
losses.backward()
# Note: plain backward() is used here; for apex mixed-precision training,
# swap it for the loss-scaling variant kept below:
# with amp.scale_loss(losses, optimizer) as scaled_losses:
#     scaled_losses.backward()
optimizer.step()

batch_time = time.time() - end
@@ -103,6 +124,66 @@
)
if iteration % checkpoint_period == 0:
checkpointer.save("model_{:07d}".format(iteration), **arguments)
if (
data_loader_val is not None
and test_period > 0
and iteration % test_period == 0
):
meters_val = MetricLogger(delimiter=" ")  # consumed only by the optional loss-logging block below
synchronize()
output_folder = os.path.join(cfg.OUTPUT_DIR, "validation", str(iteration))
os.makedirs(output_folder, exist_ok=True)
_ = inference(  # the result can be used for additional logging, e.g. to TensorBoard
model,
# inference() changes the segmentation mask format inside the data loader,
# so a fresh data loader is created for every validation run:
make_data_loader(
cfg,
is_train=False,
is_distributed=(get_world_size() > 1),
is_for_period=True,
),
dataset_name="[Validation]",
iou_types=iou_types,
box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
output_folder=output_folder,
)
synchronize()
# model.train()
# with torch.no_grad():
# # Should be one image for each GPU:
# for iteration_val, (images_val, targets_val, _) in enumerate(
# tqdm(data_loader_val)
# ):
# images_val = images_val.to(device)
# targets_val = [target.to(device) for target in targets_val]
# loss_dict = model(images_val, targets_val)
# losses = sum(loss for loss in loss_dict.values())
# loss_dict_reduced = reduce_loss_dict(loss_dict)
# losses_reduced = sum(loss for loss in loss_dict_reduced.values())
# meters_val.update(loss=losses_reduced, **loss_dict_reduced)
# synchronize()
# logger.info(
# meters_val.delimiter.join(
# [
# "[Validation]: ",
# "eta: {eta}",
# "iter: {iter}",
# "{meters}",
# "lr: {lr:.6f}",
# "max mem: {memory:.0f}",
# ]
# ).format(
# eta=eta_string,
# iter=iteration,
# meters=str(meters_val),
# lr=optimizer.param_groups[0]["lr"],
# memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
# )
# )
if iteration == max_iter:
checkpointer.save("model_final", **arguments)

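Since the inference call keeps its return value, a natural extension is pushing the metrics to TensorBoard. A hedged sketch: it assumes inference() returns an evaluation object with a COCOResults-style .results mapping on the main process and None on other ranks; verify the exact return type in your fork before relying on this:

from torch.utils.tensorboard import SummaryWriter

def log_validation_metrics(writer, eval_result, iteration):
    # Non-main ranks receive None from inference().
    if eval_result is None:
        return
    # Some forks return a (COCOResults, per-category dict) tuple.
    coco_results = eval_result[0] if isinstance(eval_result, tuple) else eval_result
    for iou_type, metrics in coco_results.results.items():
        for name, value in metrics.items():
            writer.add_scalar("val/{}/{}".format(iou_type, name), value, iteration)

Replacing the `_ =` assignment with a real name and calling this helper right after the second synchronize() keeps the logging on the main process only.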
26 changes: 21 additions & 5 deletions tools/train_net.py
@@ -36,7 +36,9 @@ def train(cfg, local_rank, distributed):

if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank,
model,
device_ids=[local_rank],
output_device=local_rank,
# this should be removed if we update BatchNorm stats
broadcast_buffers=False,
)
@@ -60,16 +62,30 @@
start_iter=arguments["iteration"],
)

test_period = cfg.SOLVER.TEST_PERIOD
if test_period > 0:
data_loader_val = make_data_loader(
cfg, is_train=False, is_distributed=distributed, is_for_period=True
)
else:
data_loader_val = None

checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

do_train(
cfg,
model,
data_loader,
data_loader_val,
optimizer,
scheduler,
checkpointer,
device,
checkpoint_period,
test_period,
arguments,
)

@@ -93,7 +109,9 @@ def run_test(cfg, model, distributed):
mkdir(output_folder)
output_folders[idx] = output_folder
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
for output_folder, dataset_name, data_loader_val in zip(
output_folders, dataset_names, data_loaders_val
):
inference(
model,
data_loader_val,
@@ -138,9 +156,7 @@ def main():

if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend="nccl", init_method="env://"
)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()

cfg.merge_from_file(args.config_file)
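For downstream projects that copy this wiring into their own entry point, the essential pattern is small. A sketch using only names from the diff above; note that every rank must build the validation loader and enter do_train, or the synchronize() calls around the periodic inference would deadlock:

from maskrcnn_benchmark.data import make_data_loader

def build_val_loader(cfg, distributed):
    # do_train skips validation entirely when it receives None.
    if cfg.SOLVER.TEST_PERIOD <= 0:
        return None
    return make_data_loader(
        cfg, is_train=False, is_distributed=distributed, is_for_period=True
    )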