Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ dist/
/datasets
/models
/output

/inference
last_checkpoint
log.txt

3 changes: 2 additions & 1 deletion configs/e2e_mask_rcnn_R_50_FPN_1x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ MODEL:
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TRAIN: ("coco_2014_train",)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert this change? All the models have been trained using the new coco_2017train dataset, which corresponds to coco_2014_train + coco_2014_valminusminival. If you want to evaluate at every N iterations, you could do it on the coco_2014_minival?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've reverted it and created another config file where the number of iterations for validation specified:
https://github.com/facebookresearch/maskrcnn-benchmark/pull/828/files#diff-4dd26a63ac00a49aeb10985800d7f21c

VAL: ("coco_2014_valminusminival",)
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
Expand Down
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.TRAIN = ()
# List of the dataset names for validating, as present in paths_catalog.py
_C.DATASETS.VAL = ()
# List of the dataset names for testing, as present in paths_catalog.py
_C.DATASETS.TEST = ()

Expand Down
26 changes: 14 additions & 12 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@

from . import datasets as D
from . import samplers
from .dataset_mode import DatasetMode

from .collate_batch import BatchCollator
from .transforms import build_transforms


def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
def build_dataset(dataset_list, transforms, dataset_catalog, mode=DatasetMode.TRAIN):
"""
Arguments:
dataset_list (list[str]): Contains the names of the datasets, i.e.,
coco_2014_trian, coco_2014_val, etc
transforms (callable): transforms to apply to each (image, target) sample
dataset_catalog (DatasetCatalog): contains the information on how to
construct a dataset.
is_train (bool): whether to setup the dataset for training or testing
mode (DatasetMode): whether to setup the dataset for training, validation, or testing
"""
if not isinstance(dataset_list, (list, tuple)):
raise RuntimeError(
Expand All @@ -36,16 +37,16 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
# for COCODataset, we want to remove images without annotations
# during training
if data["factory"] == "COCODataset":
args["remove_images_without_annotations"] = is_train
args["remove_images_without_annotations"] = mode != DatasetMode.TEST
if data["factory"] == "PascalVOCDataset":
args["use_difficult"] = not is_train
args["use_difficult"] = mode == DatasetMode.TEST
args["transforms"] = transforms
# make dataset from factory
dataset = factory(**args)
datasets.append(dataset)

# for testing, return a list of datasets
if not is_train:
if mode != DatasetMode.TEST:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though not really the best thing to do, I believe in most cases we simply evaluate on the test dataset after N iterations, so I think that we can remove the VAL part altogether.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added another boolean flag instead for controlling the way of data loader creating:
https://github.com/facebookresearch/maskrcnn-benchmark/pull/828/files#diff-48c338613bdbf422235cdb2ef17201f7R77

return datasets

# for training, concatenate all datasets into a single one
Expand Down Expand Up @@ -104,9 +105,9 @@ def make_batch_data_sampler(
return batch_sampler


def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
def make_data_loader(cfg, mode=DatasetMode.TRAIN, is_distributed=False, start_iter=0):
num_gpus = get_world_size()
if is_train:
if mode == DatasetMode.TRAIN:
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
assert (
images_per_batch % num_gpus == 0
Expand All @@ -115,6 +116,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
images_per_gpu = images_per_batch // num_gpus
shuffle = True
num_iters = cfg.SOLVER.MAX_ITER
dataset_list = cfg.DATASETS.TRAIN
else:
images_per_batch = cfg.TEST.IMS_PER_BATCH
assert (
Expand All @@ -125,6 +127,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
shuffle = False if not is_distributed else True
num_iters = None
start_iter = 0
dataset_list = cfg.DATASETS.TEST if mode == DatasetMode.TEST else cfg.DATASETS.VAL

if images_per_gpu > 1:
logger = logging.getLogger(__name__)
Expand All @@ -148,10 +151,9 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
"maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
)
DatasetCatalog = paths_catalog.DatasetCatalog
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

transforms = build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
transforms = build_transforms(cfg, mode)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, mode)

data_loaders = []
for dataset in datasets:
Expand All @@ -168,8 +170,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
collate_fn=collator,
)
data_loaders.append(data_loader)
if is_train:
# during training, a single (possibly concatenated) data_loader is returned
if mode != DatasetMode.TEST:
# during training and validation, a single (possibly concatenated) data_loader is returned
assert len(data_loaders) == 1
return data_loaders[0]
return data_loaders
8 changes: 8 additions & 0 deletions maskrcnn_benchmark/data/dataset_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Author: Petr Vytovtov <p.vytovtov@partner.samsung.com>
from enum import Enum


class DatasetMode(Enum):
TRAIN = 1
VALID = 2
TEST = 3
5 changes: 3 additions & 2 deletions maskrcnn_benchmark/data/transforms/build.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from . import transforms as T
from maskrcnn_benchmark.data.dataset_mode import DatasetMode


def build_transforms(cfg, is_train=True):
if is_train:
def build_transforms(cfg, mode=DatasetMode.TRAIN):
if mode == DatasetMode.TRAIN:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
flip_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN
Expand Down
34 changes: 33 additions & 1 deletion maskrcnn_benchmark/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
from maskrcnn_benchmark.utils.metric_logger import MetricLogger

from apex import amp
Expand Down Expand Up @@ -45,6 +45,7 @@ def do_train(
device,
checkpoint_period,
arguments,
data_loader_val=None,
):
logger = logging.getLogger("maskrcnn_benchmark.trainer")
logger.info("Start training")
Expand Down Expand Up @@ -107,6 +108,37 @@ def do_train(
)
if iteration % checkpoint_period == 0:
checkpointer.save("model_{:07d}".format(iteration), **arguments)
if data_loader_val is not None:
meters_val = MetricLogger(delimiter=" ")
synchronize()
with torch.no_grad():
for idx_val, (images_val, targets_val, _) in enumerate(data_loader_val):
images_val = images_val.to(device)
targets_val = [target.to(device) for target in targets_val]
loss_dict = model(images_val, targets_val)
losses = sum(loss for loss in loss_dict.values())
loss_dict_reduced = reduce_loss_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
meters_val.update(loss=losses_reduced, **loss_dict_reduced)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand it correctly, you only evaluate the loss here, while a metric which is generally more useful is to report the mAP as we do for testing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

@qihao-huang qihao-huang May 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line records batch's loss of val-set using current train iteration model, right?
So, if the purpose is to check our model is over fitting or not, we need to calculate the average loss of val-set using current train iteration model. And use this average loss to decide early stop.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

synchronize()
logger.info(
meters_val.delimiter.join(
[
"[Validation]: ",
"eta: {eta}",
"iter: {iter}",
"{meters}",
"lr: {lr:.6f}",
"max mem: {memory:.0f}",
]
).format(
eta=eta_string,
iter=iteration,
meters=str(meters),
Copy link

@droseger droseger May 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

str(meters) needs to be str(meters_val) here, otherwise the training metrics are displayed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh... Yes. Fixed.

lr=optimizer.param_groups[0]["lr"],
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
)
)
if iteration == max_iter:
checkpointer.save("model_final", **arguments)

Expand Down
3 changes: 2 additions & 1 deletion tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.data.dataset_mode import DatasetMode
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
Expand Down Expand Up @@ -87,7 +88,7 @@ def main():
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
mkdir(output_folder)
output_folders[idx] = output_folder
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
data_loaders_val = make_data_loader(cfg, mode=DatasetMode.TEST, is_distributed=distributed)
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
inference(
model,
Expand Down
11 changes: 9 additions & 2 deletions tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.data.dataset_mode import DatasetMode
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
Expand Down Expand Up @@ -67,10 +68,15 @@ def train(cfg, local_rank, distributed):

data_loader = make_data_loader(
cfg,
is_train=True,
mode=DatasetMode.TRAIN,
is_distributed=distributed,
start_iter=arguments["iteration"],
)
data_loader_val = make_data_loader(
cfg,
mode=DatasetMode.VALID,
is_distributed=distributed,
)

checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

Expand All @@ -83,6 +89,7 @@ def train(cfg, local_rank, distributed):
device,
checkpoint_period,
arguments,
data_loader_val,
)

return model
Expand All @@ -104,7 +111,7 @@ def run_test(cfg, model, distributed):
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
mkdir(output_folder)
output_folders[idx] = output_folder
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
data_loaders_val = make_data_loader(cfg, mode=DatasetMode.TEST, is_distributed=distributed)
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
inference(
model,
Expand Down