From 2a481bfae9e33aee7631a4d5d3f4758a520e3d8f Mon Sep 17 00:00:00 2001 From: lrh Date: Fri, 10 Nov 2023 11:01:24 +0800 Subject: [PATCH 1/4] support sparseinst batch inference --- .../mmdet/deploy/object_detection_model.py | 2 +- .../mmdet/models/dense_heads/__init__.py | 1 + .../models/dense_heads/sparseinst_head.py | 63 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index c6a958e5eb..dab5b074b6 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -241,7 +241,7 @@ def postprocessing_results(self, masks = batch_masks[i] img_h, img_w = img_metas[i]['img_shape'][:2] ori_h, ori_w = img_metas[i]['ori_shape'][:2] - if model_type in ['RTMDet', 'CondInst']: + if model_type in ['RTMDet', 'CondInst', 'SparseInst']: export_postprocess_mask = True else: export_postprocess_mask = False diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 062bc7de52..3bee17f449 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -11,5 +11,6 @@ from . import rtmdet_ins_head # noqa: F401,F403 from . import solo_head # noqa: F401,F403 from . import solov2_head # noqa: F401,F403 +from . import sparseinst_head # noqa: F401,F403 from . import yolo_head # noqa: F401,F403 from . import yolox_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py new file mode 100644 index 0000000000..6ddbb7744c --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +import torch.nn.functional as F +from mmdet.models.utils import aligned_bilinear +from mmdet.structures import OptSampleList, SampleList +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@torch.jit.script +def rescoring_mask(scores, mask_pred, masks): + mask_pred_ = mask_pred.float() + return scores * ((masks * mask_pred_).sum([2, 3]) / + (mask_pred_.sum([2, 3]) + 1e-6)) + + +@FUNCTION_REWRITER.register_rewriter( + 'projects.SparseInst.sparseinst.SparseInst.predict') +def sparseinst__predict( + self, + batch_inputs: Tensor, + batch_data_samples: List[dict], + rescale: bool = False, +): + """Rewrite `predict` of `SparseInst` for default backend.""" + max_shape = batch_inputs.shape[-2:] + x = self.extract_feat(batch_inputs) + output = self.decoder(x) + + pred_scores = output['pred_logits'].sigmoid() + pred_masks = output['pred_masks'].sigmoid() + pred_objectness = output['pred_scores'].sigmoid() + pred_scores = torch.sqrt(pred_scores * pred_objectness) + + # max/argmax + scores, labels = pred_scores.max(dim=-1) + # cls threshold + keep = scores > self.cls_threshold + scores = scores.where(keep, scores.new_zeros(1)) + labels = labels.where(keep, labels.new_zeros(1)) + keep = keep.unsqueeze(-1).unsqueeze(-1).expand_as(pred_masks) + pred_masks = pred_masks.where(keep, pred_masks.new_zeros(1)) + + img_meta = batch_data_samples[0].metainfo + # rescoring mask using maskness + scores = rescoring_mask(scores, + pred_masks > self.mask_threshold, + pred_masks) + h, w = img_meta['img_shape'][:2] + pred_masks = F.interpolate(pred_masks, + size=max_shape, + mode='bilinear', + align_corners=False)[:, :, :h, :w] + + bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4) + dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1) + masks = (pred_masks > self.mask_threshold).float() + + return dets, labels, masks From 4c463ca94253371081d8c43fd8840708305b793a Mon Sep 17 00:00:00 2001 From: lrh Date: Fri, 10 Nov 2023 11:16:49 +0800 Subject: [PATCH 2/4] fix lint error --- .../models/dense_heads/sparseinst_head.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py index 6ddbb7744c..8e4c3e487d 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py @@ -1,11 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Tuple +from typing import List import torch import torch.nn.functional as F -from mmdet.models.utils import aligned_bilinear -from mmdet.structures import OptSampleList, SampleList -from mmengine.config import ConfigDict from torch import Tensor from mmdeploy.core import FUNCTION_REWRITER @@ -30,7 +27,7 @@ def sparseinst__predict( max_shape = batch_inputs.shape[-2:] x = self.extract_feat(batch_inputs) output = self.decoder(x) - + pred_scores = output['pred_logits'].sigmoid() pred_masks = output['pred_masks'].sigmoid() pred_objectness = output['pred_scores'].sigmoid() @@ -47,15 +44,13 @@ def sparseinst__predict( img_meta = batch_data_samples[0].metainfo # rescoring mask using maskness - scores = rescoring_mask(scores, - pred_masks > self.mask_threshold, + scores = rescoring_mask(scores, pred_masks > self.mask_threshold, pred_masks) h, w = img_meta['img_shape'][:2] - pred_masks = F.interpolate(pred_masks, - size=max_shape, - mode='bilinear', - align_corners=False)[:, :, :h, :w] - + pred_masks = F.interpolate( + pred_masks, size=max_shape, mode='bilinear', + align_corners=False)[:, :, :h, :w] + bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4) dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1) masks = (pred_masks > self.mask_threshold).float() From 795d4cd5b9a35248fde6db6c76bb19d8d1459e34 Mon Sep 17 00:00:00 2001 From: lrh Date: Fri, 10 Nov 2023 22:21:44 +0800 Subject: [PATCH 3/4] add ut --- .../test_mmdet/test_mmdet_models.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 232e50ac5e..530f403a7b 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -2582,3 +2582,93 @@ def forward(self, x, param_preds, points, strides): deploy_cfg=deploy_cfg) assert rewrite_outputs is not None + + +def get_sparseinst(): + """SparseInst Config.""" + test_cfg = Config(dict(score_thr=0.4, mask_thr_binary=0.45)) + data_preprocessor = dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32) + backbone = Config( + dict( + type='ResNet', + depth=50, + out_indices=(1, 2, 3), + frozen_stages=0, + norm_cfg=dict(type='BN', requires_grad=False), + init_cfg=dict( + type='Pretrained', checkpoint='torchvision://resnet50'))) + + from projects.SparseInst.sparseinst import SparseInst + model = SparseInst( + data_preprocessor=data_preprocessor, + backbone=backbone, + encoder=dict( + type='InstanceContextEncoder', in_channels=[512, 1024, 2048]), + decoder=dict( + type='BaseIAMDecoder', in_channels=256 + 2, num_classes=80), + criterion=dict( + type='SparseInstCriterion', + num_classes=80, + assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)), + test_cfg=test_cfg, + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_sparseinst_predict(backend_type): + """Test predict rewrite of sparseinst.""" + check_backend(backend_type) + sparseinst = get_sparseinst() + sparseinst.cpu().eval() + + output_names = ['dets', 'labels', 'masks'] + deploy_cfg = Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + export_postprocess_mask=False)))) + + img = torch.randn(1, 3, 320, 320) + from mmdet.structures import DetDataSample + data_sample = DetDataSample(metainfo=dict(img_shape=(320, 320, 3))) + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel( + sparseinst, 'predict', batch_data_samples=[data_sample]) + rewrite_inputs = {'batch_inputs': img} + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + assert rewrite_outputs[0].shape[-1] == 5 + assert rewrite_outputs[1] is not None + assert rewrite_outputs[2] is not None + else: + assert rewrite_outputs is not None From 53d423857d8ab5c5808a784ae80db72cf7804f18 Mon Sep 17 00:00:00 2001 From: lrh Date: Fri, 10 Nov 2023 22:26:29 +0800 Subject: [PATCH 4/4] update docs --- docs/en/04-supported-codebases/mmdet.md | 1 + docs/zh_cn/04-supported-codebases/mmdet.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index 16bbacb299..3e7815a413 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -220,6 +220,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | | [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | +| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N | diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index c131f76698..7b1f431158 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -223,6 +223,7 @@ cv2.imwrite('output_detection.png', img) | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | | [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | +| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |