From 9b9f0a5d35a7ff00a41ea3cc0288b0b281e39631 Mon Sep 17 00:00:00 2001 From: CoinCheung <867153576@qq.com> Date: Sat, 1 Jun 2019 09:11:31 +0800 Subject: [PATCH] on-line hard example mining --- maskrcnn_benchmark/config/defaults.py | 2 + maskrcnn_benchmark/layers/smooth_l1_loss.py | 12 ++- .../balanced_positive_negative_sampler.py | 50 ++++++++++ .../modeling/roi_heads/box_head/box_head.py | 7 ++ .../modeling/roi_heads/box_head/loss.py | 97 ++++++++++++++++--- maskrcnn_benchmark/modeling/rpn/loss.py | 2 +- .../modeling/rpn/retinanet/loss.py | 2 +- 7 files changed, 153 insertions(+), 19 deletions(-) diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index 65fbdaddd..4c6b141c2 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -191,6 +191,8 @@ _C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 # Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0) _C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 +# whether to use hard-mining +_C.MODEL.ROI_HEADS.OHEM = False # Only used on test mode diff --git a/maskrcnn_benchmark/layers/smooth_l1_loss.py b/maskrcnn_benchmark/layers/smooth_l1_loss.py index 9c4664bb4..fa20a12c0 100644 --- a/maskrcnn_benchmark/layers/smooth_l1_loss.py +++ b/maskrcnn_benchmark/layers/smooth_l1_loss.py @@ -3,7 +3,7 @@ # TODO maybe push this to nn? -def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): +def smooth_l1_loss(input, target, beta=1. / 9, reduction='mean'): """ very similar to the smooth_l1_loss from pytorch, but with the extra beta parameter @@ -11,6 +11,10 @@ def smooth_l1_loss(input, target, beta=1. 
/ 9, size_average=True): n = torch.abs(input - target) cond = n < beta loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) - if size_average: - return loss.mean() - return loss.sum() + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + elif reduction == 'none': + pass + return loss diff --git a/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py index c0bd00444..26386a732 100644 --- a/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py +++ b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py @@ -66,3 +66,53 @@ def __call__(self, matched_idxs): neg_idx.append(neg_idx_per_image_mask) return pos_idx, neg_idx + + +class OhemPositiveNegativeSampler(object): + """ + This class keeps every positive and every negative element of each image; no fixed batch size or positive fraction is enforced + """ + + def __init__(self): + """ + No arguments are required: + the sampler has nothing to configure, it keeps all positive and negative elements + """ + pass + + def __call__(self, matched_idxs): + """ + Arguments: + matched_idxs: list of tensors containing -1, 0 or positive values. + Each tensor corresponds to a specific image. + -1 values are ignored, 0 are considered as negatives and > 0 as + positives. + + Returns: + pos_idx (list[tensor]) + neg_idx (list[tensor]) + + Returns two lists of binary masks for each image. + The first list marks all the positive elements, + and the second list all the negative elements.
+ """ + pos_idx = [] + neg_idx = [] + for matched_idxs_per_image in matched_idxs: + pos_idx_per_image = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) + neg_idx_per_image = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) + + # create binary mask from indices + pos_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.uint8 + ) + neg_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.uint8 + ) + pos_idx_per_image_mask[pos_idx_per_image] = 1 + neg_idx_per_image_mask[neg_idx_per_image] = 1 + + pos_idx.append(pos_idx_per_image_mask) + neg_idx.append(neg_idx_per_image_mask) + + return pos_idx, neg_idx diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py index 482081b8d..63e930742 100644 --- a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py @@ -20,6 +20,7 @@ def __init__(self, cfg, in_channels): cfg, self.feature_extractor.out_channels) self.post_processor = make_roi_box_post_processor(cfg) self.loss_evaluator = make_roi_box_loss_evaluator(cfg) + self.ohem = cfg.MODEL.ROI_HEADS.OHEM def forward(self, features, proposals, targets=None): """ @@ -41,6 +42,12 @@ def forward(self, features, proposals, targets=None): # positive / negative ratio with torch.no_grad(): proposals = self.loss_evaluator.subsample(proposals, targets) + if self.ohem: + x = self.feature_extractor(features, proposals) + class_logits, box_regression = self.predictor(x) + proposals = self.loss_evaluator.mining( + [class_logits], [box_regression] + ) # extract features that will be fed to the final classifier. 
The # feature_extractor generally corresponds to the pooler + heads diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py index 9f2771d02..3c3f4bb96 100644 --- a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py @@ -7,7 +7,8 @@ from maskrcnn_benchmark.modeling.matcher import Matcher from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import ( - BalancedPositiveNegativeSampler + BalancedPositiveNegativeSampler, + OhemPositiveNegativeSampler ) from maskrcnn_benchmark.modeling.utils import cat @@ -19,21 +20,23 @@ class FastRCNNLossComputation(object): """ def __init__( - self, - proposal_matcher, - fg_bg_sampler, - box_coder, + self, + proposal_matcher, + fg_bg_sampler, + box_coder, + batch_size_per_image, cls_agnostic_bbox_reg=False ): """ Arguments: proposal_matcher (Matcher) - fg_bg_sampler (BalancedPositiveNegativeSampler) + fg_bg_sampler (BalancedPositiveNegativeSampler, or OhemPositiveNegativeSampler) box_coder (BoxCoder) """ self.proposal_matcher = proposal_matcher self.fg_bg_sampler = fg_bg_sampler self.box_coder = box_coder + self.batch_size_per_image = batch_size_per_image self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg def match_targets_to_proposals(self, proposal, target): @@ -105,16 +108,79 @@ def subsample(self, proposals, targets): # distributed sampled proposals, that were obtained on all feature maps # concatenated via the fg_bg_sampler, into individual feature map levels + self.n_proposals_per_img = [] for img_idx, (pos_inds_img, neg_inds_img) in enumerate( zip(sampled_pos_inds, sampled_neg_inds) ): img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) proposals_per_image = proposals[img_idx][img_sampled_inds] + self.n_proposals_per_img.append(len(proposals_per_image)) proposals[img_idx] = proposals_per_image 
self._proposals = proposals return proposals + def mining(self, class_logits, box_regression): + """ + Similar role as subsample(), but returns the RoIs with the top loss. + + Arguments: + class_logits (list[Tensor]) + box_regression (list[Tensor]) + + Returns: + proposals (list[BoxList]) + """ + + class_logits = cat(class_logits, dim=0) + box_regression = cat(box_regression, dim=0) + device = class_logits.device + + if not hasattr(self, "_proposals"): + raise RuntimeError("subsample needs to be called before") + + proposals = self._proposals + + labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0) + regression_targets = cat( + [proposal.get_field("regression_targets") for proposal in proposals], dim=0 + ) + + classification_loss = F.cross_entropy(class_logits, labels, reduction='none') + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) + labels_pos = labels[sampled_pos_inds_subset] + if self.cls_agnostic_bbox_reg: + map_inds = torch.tensor([4, 5, 6, 7], device=device) + else: + map_inds = 4 * labels_pos[:, None] + torch.tensor( + [0, 1, 2, 3], device=device) + + box_loss = smooth_l1_loss( + box_regression[sampled_pos_inds_subset[:, None], map_inds], + regression_targets[sampled_pos_inds_subset], + reduction='none', + beta=1, + ).sum(dim=1, keepdim=True) + ohem_loss = classification_loss.clone() + ohem_loss[sampled_pos_inds_subset[:, None]] = ohem_loss[sampled_pos_inds_subset[:, None]] + box_loss + if ohem_loss.size(0) > self.batch_size_per_image: + ohem_idx = ohem_loss.topk(self.batch_size_per_image)[1].cpu() + lengs = [0,] + self.n_proposals_per_img + milestones = torch.cumsum(torch.tensor(lengs), dim=0) + ms1 = milestones[:-1] + ms2 = milestones[1:] + ohem_idx = ohem_idx.sort()[0] + lengs = [torch.sum((el1 <= ohem_idx)*(ohem_idx < el2)) for el1, el2 in zip(ms1, ms2)] + ohem_idx =
ohem_idx.split(lengs) + ohem_idx = [el-ms1[i] for i, el in enumerate(ohem_idx)] + self._proposals = [proposals[i][el] for i, el in enumerate(ohem_idx)] + + return self._proposals + def __call__(self, class_logits, box_regression): """ Computes the loss for Faster R-CNN. @@ -159,7 +225,7 @@ def __call__(self, class_logits, box_regression): box_loss = smooth_l1_loss( box_regression[sampled_pos_inds_subset[:, None], map_inds], regression_targets[sampled_pos_inds_subset], - size_average=False, + reduction='sum', beta=1, ) box_loss = box_loss / labels.numel() @@ -177,16 +243,21 @@ def make_roi_box_loss_evaluator(cfg): bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS box_coder = BoxCoder(weights=bbox_reg_weights) - fg_bg_sampler = BalancedPositiveNegativeSampler( - cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION - ) + if cfg.MODEL.ROI_HEADS.OHEM: + fg_bg_sampler = OhemPositiveNegativeSampler() + else: + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, + cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION + ) cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG loss_evaluator = FastRCNNLossComputation( - matcher, - fg_bg_sampler, - box_coder, + matcher, + fg_bg_sampler, + box_coder, + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cls_agnostic_bbox_reg ) diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py index 840e35453..0d70a552e 100644 --- a/maskrcnn_benchmark/modeling/rpn/loss.py +++ b/maskrcnn_benchmark/modeling/rpn/loss.py @@ -121,7 +121,7 @@ def __call__(self, anchors, objectness, box_regression, targets): box_regression[sampled_pos_inds], regression_targets[sampled_pos_inds], beta=1.0 / 9, - size_average=False, + reduction='sum', ) / (sampled_inds.numel()) objectness_loss = F.binary_cross_entropy_with_logits( diff --git a/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py b/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py index 080e2153b..c1978ebeb 100644 
--- a/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py +++ b/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py @@ -67,7 +67,7 @@ def __call__(self, anchors, box_cls, box_regression, targets): box_regression[pos_inds], regression_targets[pos_inds], beta=self.bbox_reg_beta, - size_average=False, + reduction='sum', ) / (max(1, pos_inds.numel() * self.regress_norm)) labels = labels.int()