diff --git a/README.md b/README.md
index 2f3738c..643d01f 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,9 @@ python run_net.py --config-file=configs/base.py --task=test
| RSDet-R50-FPN | DOTA1.0|1024/200|Flip|-| SGD | 1x | 68.41 | [arxiv](https://arxiv.org/abs/1911.08299) | [config](configs/rotated_retinanet/rsdet_obb_r50_fpn_1x_dota_lmr5p.py) | [model](https://cloud.tsinghua.edu.cn/f/642e200f5a8a420eb726/?dl=1) |
| ATSS-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 72.44 | [arxiv](https://arxiv.org/abs/1912.02424) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/5168189dcd364eaebce5/?dl=1) |
| Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 56.34 | [arxiv](https://arxiv.org/abs/1904.11490) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/be359ac932c84f9c839e/?dl=1) |
+| CFA-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | | [arxiv]() | [config](configs/cfa_r50_fpn_1x_dota.py) | [model]() |
+| Oriented-Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 66.99 | [arxiv](https://arxiv.org/abs/2105.11111) | [config](projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py) | [model]() |
+| SASM-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | | [arxiv]() | [config](configs/sasm_obb_r50_fpn_1x_dota.py) | [model]() |
**Notice**:
@@ -153,9 +156,11 @@ python run_net.py --config-file=configs/base.py --task=test
- :heavy_check_mark: Reppoints
- :heavy_check_mark: RSDet
- :heavy_check_mark: ATSS
+- :heavy_check_mark: CFA
+- :heavy_check_mark: Oriented Reppoints
+- :heavy_check_mark: SASM
- :clock3: R3Det
- :clock3: Cascade R-CNN
-- :clock3: Oriented Reppoints
- :heavy_plus_sign: DCL
- :heavy_plus_sign: Double Head OBB
- :heavy_plus_sign: Guided Anchoring
diff --git a/configs/cfa_r50_fpn_1x_dota.py b/configs/cfa_r50_fpn_1x_dota.py
new file mode 100644
index 0000000..4f341b4
--- /dev/null
+++ b/configs/cfa_r50_fpn_1x_dota.py
@@ -0,0 +1,161 @@
+# model settings
+model = dict(
+ type='RotatedRetinaNet',
+ backbone=dict(
+ type='Resnet50',
+ frozen_stages=1,
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs="on_input",
+ norm_cfg = dict(type='GN', num_groups=32),
+ num_outs=5),
+ bbox_head=dict(
+ type='RotatedRepPointsHead',
+ num_classes=15,
+ in_channels=256,
+ feat_channels=256,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.3,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=2,
+ norm_cfg = dict(type='GN', num_groups=32),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375),
+ loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0),
+ transform_method='rotrect',
+ use_reassign=True,
+ topk=6,
+ anti_factor=0.75,
+ train_cfg=dict(
+ init=dict(
+ assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine=dict(
+ assigner=dict(
+ type='MaxConvexIoUAssigner',
+ pos_iou_thr=0.1,
+ neg_iou_thr=0.1,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ assigned_labels_filled=-1,
+ iou_calculator=dict(type='ConvexOverlaps')),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.4),
+ max_per_img=2000))
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=1,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/zytx121/mmrotate/data/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
+)
+
+optimizer = dict(
+ type='SGD',
+ lr=0.008,
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
+
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[7, 10])
+
+
+logger = dict(
+ type="RunLogger")
+
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/configs/sasm_obb_r50_fpn_1x_dota.py b/configs/sasm_obb_r50_fpn_1x_dota.py
new file mode 100644
index 0000000..5397876
--- /dev/null
+++ b/configs/sasm_obb_r50_fpn_1x_dota.py
@@ -0,0 +1,155 @@
+# model settings
+model = dict(
+ type='RotatedRetinaNet',
+ backbone=dict(
+ type='Resnet50',
+ frozen_stages=1,
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs="on_input",
+ norm_cfg = dict(type='GN', num_groups=32),
+ num_outs=5),
+ bbox_head=dict(
+ type='SAMRepPointsHead',
+ num_classes=15,
+ in_channels=256,
+ feat_channels=256,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.3,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=2,
+ norm_cfg = dict(type='GN', num_groups=32),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375),
+ loss_bbox_refine=dict(type='BCConvexGIoULoss', loss_weight=1.0),
+ transform_method='rotrect',
+ use_reassign=False,
+ topk=6,
+ anti_factor=0.75,
+ train_cfg=dict(
+ init=dict(
+ assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine=dict(
+ assigner=dict(type='SASAssigner', topk=9),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.4),
+ max_per_img=2000))
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/zytx121/mmrotate/data/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
+)
+
+optimizer = dict(
+ type='SGD',
+ # lr=0.01/4., #0.0,#0.01*(1/8.),
+ lr=0.008,
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
+
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[7, 10])
+
+
+logger = dict(
+ type="RunLogger")
+
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/projects/oriented_reppoints/README.md b/projects/oriented_reppoints/README.md
new file mode 100644
index 0000000..22e4b01
--- /dev/null
+++ b/projects/oriented_reppoints/README.md
@@ -0,0 +1,45 @@
+## Oriented RepPoints
+> [Oriented RepPoints for Aerial Object Detection](https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Oriented_RepPoints_for_Aerial_Object_Detection_CVPR_2022_paper.pdf)
+
+
+### Abstract
+
+
+

+
+
+In contrast to the generic object, aerial targets are often non-axis aligned with arbitrary orientations having
+the cluttered surroundings. Unlike the mainstreamed approaches regressing the bounding box orientations, this paper
+proposes an effective adaptive points learning approach to aerial object detection by taking advantage of the adaptive
+points representation, which is able to capture the geometric information of the arbitrary-oriented instances.
+To this end, three oriented conversion functions are presented to facilitate the classification and localization
+with accurate orientation. Moreover, we propose an effective quality assessment and sample assignment scheme for
+adaptive points learning toward choosing the representative oriented reppoints samples during training, which is
+able to capture the non-axis aligned features from adjacent objects or background noises. A spatial constraint is
+introduced to penalize the outlier points for robust adaptive learning. Experimental results on four challenging
+aerial datasets including DOTA, HRSC2016, UCAS-AOD and DIOR-R, demonstrate the efficacy of our proposed approach.
+
+### Training
+```sh
+python run_net.py --config-file=configs/oriented_reppoints_r50_fpn_1x_dota.py --task=train
+```
+
+### Testing
+```sh
+python run_net.py --config-file=configs/oriented_reppoints_r50_fpn_1x_dota.py --task=test
+```
+
+### Performance
+| Models | Dataset| Sub_Image_Size/Overlap |Train Aug | Test Aug | Optim | Lr schd | mAP | Paper | Config | Download |
+|:-----------:| :-----: |:-----:|:-----:| :-----: | :-----:| :-----:| :----: |:--------:|:--------------------------------------------------------------:| :--------: |
+| Oriented-Reppoints-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 66.99 | [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Oriented_RepPoints_for_Aerial_Object_Detection_CVPR_2022_paper.pdf)| [config](configs/oriented_reppoints_r50_fpn_1x_dota.py) | [model]() |
+
+### Citation
+```
+@inproceedings{li2022ori,
+ title={Oriented RepPoints for Aerial Object Detection},
+ author={Wentong Li and Yijie Chen and Kaixuan Hu and Jianke Zhu},
+ booktitle={Proceedings of IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py b/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py
new file mode 100644
index 0000000..8a620c1
--- /dev/null
+++ b/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py
@@ -0,0 +1,165 @@
+# model settings
+model = dict(
+ type='RotatedRetinaNet',
+ backbone=dict(
+ type='Resnet50',
+ frozen_stages=1,
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs="on_input",
+ norm_cfg = dict(type='GN', num_groups=32),
+ num_outs=5),
+ bbox_head=dict(
+ type='OrientedRepPointsHead',
+ num_classes=15,
+ in_channels=256,
+ feat_channels=256,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.3,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=2,
+ norm_cfg = dict(type='GN', num_groups=32),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375),
+ loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0),
+ loss_spatial_init=dict(type='SpatialBorderLoss', loss_weight=0.05),
+ loss_spatial_refine=dict(type='SpatialBorderLoss', loss_weight=0.1),
+ transform_method='rotrect',
+ use_reassign=False,
+ init_qua_weight=0.2,
+ top_ratio=0.4,
+ # topk=6,
+ # anti_factor=0.75,
+ train_cfg=dict(
+ init=dict(
+ assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine=dict(
+ assigner=dict(
+ type='MaxConvexIoUAssigner',
+ pos_iou_thr=0.1,
+ neg_iou_thr=0.1,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ assigned_labels_filled=-1,
+ iou_calculator=dict(type='ConvexOverlaps')),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.4),
+ max_per_img=2000))
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"),
+ dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
+)
+
+optimizer = dict(
+ type='SGD',
+ lr=0.008,
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
+
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[7, 10])
+
+
+logger = dict(
+ type="RunLogger")
+
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/projects/oriented_reppoints/configs/oriented_reppoints_test.py b/projects/oriented_reppoints/configs/oriented_reppoints_test.py
new file mode 100644
index 0000000..b8094e8
--- /dev/null
+++ b/projects/oriented_reppoints/configs/oriented_reppoints_test.py
@@ -0,0 +1,162 @@
+# model settings
+model = dict(
+ type='RotatedRetinaNet',
+ backbone=dict(
+ type='Resnet50',
+ frozen_stages=1,
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs="on_input",
+ norm_cfg = dict(type='GN', num_groups=32),
+ num_outs=5),
+ bbox_head=dict(
+ type='RotatedRepPointsHead',
+ num_classes=15,
+ in_channels=256,
+ feat_channels=256,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.3,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=2,
+ norm_cfg = dict(type='GN', num_groups=32),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375),
+ loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0),
+ transform_method='rotrect',
+ use_reassign=False,
+ topk=6,
+ anti_factor=0.75,
+ train_cfg=dict(
+ init=dict(
+ assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine=dict(
+ assigner=dict(
+ type='MaxConvexIoUAssigner',
+ pos_iou_thr=0.4,
+ neg_iou_thr=0.3,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ assigned_labels_filled=-1,
+ iou_calculator=dict(type='ConvexOverlaps')),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.4),
+ max_per_img=2000))
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ # dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"),
+ # dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
+)
+
+optimizer = dict(
+ type='SGD',
+ # lr=0.01/4., #0.0,#0.01*(1/8.),
+ lr=0.008,
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
+
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[7, 10])
+
+
+logger = dict(
+ type="RunLogger")
+
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/projects/oriented_reppoints/run_net.py b/projects/oriented_reppoints/run_net.py
new file mode 100644
index 0000000..25f1d65
--- /dev/null
+++ b/projects/oriented_reppoints/run_net.py
@@ -0,0 +1,56 @@
+import argparse
+import jittor as jt
+from jdet.runner import Runner
+from jdet.config import init_cfg
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Jittor Object Detection Training")
+ parser.add_argument(
+ "--config-file",
+ default="",
+ metavar="FILE",
+ help="path to config file",
+ type=str,
+ )
+ parser.add_argument(
+ "--task",
+ default="train",
+ help="train,val,test",
+ type=str,
+ )
+
+ parser.add_argument(
+ "--no_cuda",
+ action='store_true'
+ )
+
+ parser.add_argument(
+ "--save_dir",
+ default=".",
+ type=str,
+ )
+
+ args = parser.parse_args()
+
+ if not args.no_cuda:
+ jt.flags.use_cuda=1
+
+ assert args.task in ["train","val","test","vis_test"],f"{args.task} not support, please choose [train,val,test,vis_test]"
+
+ if args.config_file:
+ init_cfg(args.config_file)
+
+ runner = Runner()
+
+ if args.task == "train":
+ runner.run()
+ elif args.task == "val":
+ runner.val()
+ elif args.task == "test":
+ runner.test()
+ elif args.task == "vis_test":
+ runner.run_on_images(args.save_dir)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/projects/oriented_reppoints/test_oriented_reppoints.py b/projects/oriented_reppoints/test_oriented_reppoints.py
new file mode 100644
index 0000000..b2b34b8
--- /dev/null
+++ b/projects/oriented_reppoints/test_oriented_reppoints.py
@@ -0,0 +1,72 @@
+import jittor as jt
+from jdet.config import init_cfg, get_cfg
+from jdet.utils.general import parse_losses
+from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS
+import argparse
+import os
+import pickle as pk
+import jdet
+
+def main():
+ parser = argparse.ArgumentParser(description="Jittor Object Detection Training")
+ parser.add_argument(
+ "--set_data",
+ action='store_true'
+ )
+ args = parser.parse_args()
+
+ jt.flags.use_cuda=1
+ jt.set_global_seed(666)
+ init_cfg("projects/oriented_reppoints/configs/oriented_reppoints_test.py")
+ cfg = get_cfg()
+
+ model = build_from_cfg(cfg.model,MODELS)
+ optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params= model.parameters())
+
+ model.train()
+ if (args.set_data):
+ imagess = []
+ targetss = []
+ correct_loss = []
+ train_dataset = build_from_cfg(cfg.dataset.train,DATASETS)
+ for batch_idx,(images,targets) in enumerate(train_dataset):
+ if (batch_idx > 10):
+ break
+ print(batch_idx)
+ imagess.append(jdet.utils.general.sync(images))
+ targetss.append(jdet.utils.general.sync(targets))
+ losses = model(images,targets)
+ all_loss,losses = parse_losses(losses)
+ optimizer.step(all_loss)
+ correct_loss.append(all_loss.item())
+ data = {
+ "imagess": imagess,
+ "targetss": targetss,
+ "correct_loss": correct_loss,
+ }
+ if (not os.path.exists("test_datas_oriented_reppoints")):
+ os.makedirs("test_datas_oriented_reppoints")
+ pk.dump(data, open("test_datas_oriented_reppoints/test_data.pk", "wb"))
+ print(correct_loss)
+ else:
+ data = pk.load(open("test_datas_oriented_reppoints/test_data.pk", "rb"))
+ imagess = jdet.utils.general.to_jt_var(data["imagess"])
+ targetss = jdet.utils.general.to_jt_var(data["targetss"])
+ correct_loss = data["correct_loss"]
+
+ for batch_idx in range(len(imagess)):
+ images = imagess[batch_idx]
+ targets = targetss[batch_idx]
+ losses = model(images,targets)
+ all_loss,losses = parse_losses(losses)
+ optimizer.step(all_loss)
+ l = all_loss.item()
+ c_l = correct_loss[batch_idx]
+ err_rate = abs(c_l-l)/min(c_l,l)
+ print(f"correct loss is {c_l:.4f}, runtime loss is {l:.4f}, err rate is {err_rate*100:.2f}%")
+ assert err_rate<1e-3,"LOSS is not correct, please check it"
+ print(f"Loss is correct with err_rate<{1e-3}")
+ print("success!")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/python/jdet/models/boxes/assigner.py b/python/jdet/models/boxes/assigner.py
index 253aeac..c777534 100644
--- a/python/jdet/models/boxes/assigner.py
+++ b/python/jdet/models/boxes/assigner.py
@@ -1,6 +1,7 @@
import jittor as jt
from jdet.utils.registry import BOXES,build_from_cfg
from jdet.models.boxes.box_ops import points_in_rotated_boxes
+from jdet.ops.reppoints_convex_iou import reppoints_convex_iou
import numpy as np
def deleteme(a, b, size = 10):
@@ -414,10 +415,10 @@ def get_horizontal_bboxes(self, gt_rbboxes):
"""get_horizontal_bboxes from polygons.
Args:
- gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8).
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
Returns:
- gt_rect_bboxes (torch.Tensor): The horizontal bboxes, shape (k, 4).
+ gt_rect_bboxes (jt.Tensor): The horizontal bboxes, shape (k, 4).
"""
gt_xs, gt_ys = gt_rbboxes[:, 0::2], gt_rbboxes[:, 1::2]
gt_xmin = gt_xs.min(1)
@@ -451,8 +452,8 @@ def assign(self,
6. limit the positive sample's center in gt
Args:
- points (torch.Tensor): Points to be assigned, shape(n, 18).
- gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8).
+ points (jt.Tensor): Points to be assigned, shape(n, 18).
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
gt_rbboxes_ignore (Tensor, optional): Ground truth polygons that
are labelled as `ignored`, e.g., crowd boxes in COCO.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
@@ -465,15 +466,17 @@ def assign(self,
if num_gts == 0 or num_points == 0:
# If no truth assign everything to the background
- assigned_gt_inds = points.new_full((num_points, ),
- 0,
- dtype=jt.int32)
+ assigned_gt_inds = jt.full((num_points, ), 0, dtype=jt.int32)
+ # assigned_gt_inds = points.new_full((num_points, ),
+ # 0,
+ # dtype=jt.int32)
if gt_labels is None:
assigned_labels = None
else:
- assigned_labels = points.new_full((num_points, ),
- -1,
- dtype=jt.int32)
+            assigned_labels = jt.full((num_points, ), -1, dtype=jt.int32)
+ # assigned_labels = points.new_full((num_points, ),
+ # -1,
+ # dtype=jt.int32)
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
@@ -588,9 +591,9 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
6. limit the positive sample's center in gt
Args:
- points (torch.Tensor): Points to be assigned, shape(n, 18).
- gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8).
- overlaps (torch.Tensor): Overlaps between k gt_bboxes and n bboxes,
+ points (jt.Tensor): Points to be assigned, shape(n, 18).
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
+ overlaps (jt.Tensor): Overlaps between k gt_bboxes and n bboxes,
shape(k, n).
gt_rbboxes_ignore (Tensor, optional): Ground truth polygons that
are labelled as `ignored`, e.g., crowd boxes in COCO.
@@ -599,6 +602,8 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
Returns:
:obj:`AssignResult`: The assign result.
"""
+
+
if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
@@ -609,3 +614,217 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
assert NotImplementedError
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
return assign_result
+
+@BOXES.register_module()
+class SASAssigner:
+ """Assign a corresponding gt bbox or background to each bbox. Each
+ proposals will be assigned with `0` or a positive integer indicating the
+ ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ scale (float): IoU threshold for positive bboxes.
+ pos_num (float): find the nearest pos_num points to gt center in this
+ level.
+ """
+
+ def __init__(self, topk):
+ self.topk = topk
+
+ def assign(self,
+ bboxes,
+ num_level_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None):
+ """Assign gt to bboxes.
+
+ The assignment is done in following steps
+
+ 1. compute iou between all bbox (bbox of all pyramid levels) and gt
+ 2. compute center distance between all bbox and gt
+ 3. on each pyramid level, for each gt, select k bbox whose center
+ are closest to the gt center, so we total select k*l bbox as
+ candidates for each gt
+ 4. get corresponding iou for the these candidates, and compute the
+ mean and std, set mean + std as the iou threshold
+ 5. select these candidates whose iou are greater than or equal to
+ the threshold as positive
+ 6. limit the positive sample's center in gt
+
+ Args:
+ bboxes (jt.Tensor): Bounding boxes to be assigned, shape(n, 4).
+ num_level_bboxes (List): num of bboxes in each level
+ gt_bboxes (jt.Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ INF = 100000000
+
+ num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+ overlaps = convex_overlaps(gt_bboxes, bboxes)
+ # assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ # 0,
+ # dtype=jt.long)
+ assigned_gt_inds = jt.full((num_bboxes, ), 0, dtype=jt.int32)
+
+ if num_gt == 0 or num_bboxes == 0:
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gt == 0:
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = jt.full((num_bboxes, ), -1, dtype=jt.int32)
+ # assigned_labels = overlaps.new_full((num_bboxes, ),
+ # -1,
+ # dtype=jt.long)
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ # compute center distance between all bbox and gt
+ # the center of poly
+ gt_bboxes_hbb = get_horizontal_bboxes(gt_bboxes)
+
+ gt_cx = (gt_bboxes_hbb[:, 0] + gt_bboxes_hbb[:, 2]) / 2.0
+ gt_cy = (gt_bboxes_hbb[:, 1] + gt_bboxes_hbb[:, 3]) / 2.0
+ gt_points = jt.stack((gt_cx, gt_cy), dim=1)
+
+ bboxes = bboxes.reshape(-1, 9, 2)
+ pts_x = bboxes[:, :, 0::2]
+ pts_y = bboxes[:, :, 1::2]
+
+ pts_x_mean = pts_x.mean(dim=1).squeeze()
+ pts_y_mean = pts_y.mean(dim=1).squeeze()
+
+ bboxes_points = jt.stack((pts_x_mean, pts_y_mean), dim=1)
+
+ distances = (bboxes_points[:, None, :] -
+ gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+ # Selecting candidates based on the center distance
+ candidate_idxs = []
+ start_idx = 0
+ for level, bboxes_per_level in enumerate(num_level_bboxes):
+ end_idx = start_idx + bboxes_per_level
+ distances_per_level = distances[start_idx:end_idx, :]
+            _, topk_idxs_per_level = distances_per_level.topk(
+                min(self.topk, bboxes_per_level), dim=0, largest=False)
+ candidate_idxs.append(topk_idxs_per_level + start_idx)
+ start_idx = end_idx
+ candidate_idxs = jt.contrib.concat(candidate_idxs, dim=0)
+
+ gt_bboxes_ratios = AspectRatio(gt_bboxes)
+ gt_bboxes_ratios_per_gt = gt_bboxes_ratios.mean(0)
+ candidate_overlaps = overlaps[candidate_idxs, jt.arange(num_gt)]
+ overlaps_mean_per_gt = candidate_overlaps.mean(0)
+ overlaps_std_per_gt = candidate_overlaps.std()
+ overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+ # new assign
+ iou_thr_weight = jt.exp((-1 / 4) * gt_bboxes_ratios_per_gt)
+ overlaps_thr_per_gt = overlaps_thr_per_gt * iou_thr_weight
+ is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+ # limit the positive sample's center in gt
+ # inside_flag = jt.full([num_bboxes, num_gt],
+ # 0.).to(gt_bboxes.device).float()
+ # inside_flag = points_in_polygons(bboxes_points, gt_bboxes)
+ inside_flag = points_in_rotated_boxes(bboxes_points, gt_bboxes)
+ # pointsJf(bboxes_points, gt_bboxes, inside_flag)
+ is_in_gts = inside_flag[candidate_idxs,
+ jt.arange(num_gt)].to(is_pos.dtype)
+
+ is_pos = is_pos & is_in_gts
+ for gt_idx in range(num_gt):
+ candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+ candidate_idxs = candidate_idxs.view(-1)
+
+ # if an anchor box is assigned to multiple gts,
+ # the one with the highest IoU will be selected.
+ overlaps_inf = jt.full_like(overlaps,
+ -INF).t().contiguous().view(-1)
+ index = candidate_idxs.view(-1)[is_pos.view(-1)]
+
+ overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+ overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+ argmax_overlaps, max_overlaps = overlaps_inf.argmax(dim=1)
+ assigned_gt_inds[
+ max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+ if gt_labels is not None and bool(assigned_gt_inds.sum()):
+ # assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ assigned_labels = jt.full((num_bboxes, ), -1, dtype=jt.int32)
+ pos_inds = jt.nonzero(
+ assigned_gt_inds > 0).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+def AspectRatio(gt_rbboxes):
+ """Compute the aspect ratio of all gts.
+
+ Args:
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
+
+ Returns:
+ ratios (jt.Tensor): The aspect ratio of gt_rbboxes, shape (k, 1).
+ """
+ pt1, pt2, pt3, pt4 = gt_rbboxes[..., :8].chunk(4, 1)
+ edge1 = jt.sqrt(
+ jt.pow(pt1[..., 0] - pt2[..., 0], 2) +
+ jt.pow(pt1[..., 1] - pt2[..., 1], 2))
+ edge2 = jt.sqrt(
+ jt.pow(pt2[..., 0] - pt3[..., 0], 2) +
+ jt.pow(pt2[..., 1] - pt3[..., 1], 2))
+ edges = jt.stack([edge1, edge2], dim=1)
+ width = jt.max(edges, 1)
+ height = jt.min(edges, 1)
+ ratios = (width / height)
+ return ratios
+
+def get_horizontal_bboxes(gt_rbboxes):
+ """Get horizontal bboxes from polygons.
+
+ Args:
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
+
+ Returns:
+ gt_rect_bboxes (jt.Tensor): The horizontal bboxes, shape (k, 4).
+ """
+ gt_xs, gt_ys = gt_rbboxes[:, 0::2], gt_rbboxes[:, 1::2]
+ gt_xmin = gt_xs.min(1)
+ gt_ymin = gt_ys.min(1)
+ gt_xmax = gt_xs.max(1)
+ gt_ymax = gt_ys.max(1)
+ gt_rect_bboxes = jt.contrib.concat([
+ gt_xmin[:, None], gt_ymin[:, None], gt_xmax[:, None], gt_ymax[:, None]
+ ],
+ dim=1)
+ return gt_rect_bboxes
+
+def convex_overlaps(gt_rbboxes, points):
+ """Compute overlaps between polygons and points.
+
+ Args:
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
+ points (jt.Tensor): Points to be assigned, shape(n, 18).
+
+ Returns:
+ overlaps (jt.Tensor): Overlaps between k gt_bboxes and n bboxes,
+ shape(k, n).
+ """
+ if gt_rbboxes.shape[0] == 0:
+ return gt_rbboxes.new_zeros((0, points.shape[0]))
+ overlaps = reppoints_convex_iou(points, gt_rbboxes)
+ return overlaps
diff --git a/python/jdet/models/losses/__init__.py b/python/jdet/models/losses/__init__.py
index fb153d1..b0d3fe6 100644
--- a/python/jdet/models/losses/__init__.py
+++ b/python/jdet/models/losses/__init__.py
@@ -13,4 +13,5 @@
from .kd_loss import IMLoss
from .rsdet_loss import RSDetLoss
from .ridet_loss import RIDetLoss
-from .convex_giou_loss import ConvexGIoULoss
+from .convex_giou_loss import ConvexGIoULoss, BCConvexGIoULoss
+from .spatial_border_loss import SpatialBorderLoss
\ No newline at end of file
diff --git a/python/jdet/models/losses/convex_giou_loss.py b/python/jdet/models/losses/convex_giou_loss.py
index 35218b5..0d9dcc6 100644
--- a/python/jdet/models/losses/convex_giou_loss.py
+++ b/python/jdet/models/losses/convex_giou_loss.py
@@ -18,9 +18,9 @@ def execute(self,
Args:
ctx: {save_for_backward, convex_points_grad}
- pred (torch.Tensor): Predicted convexes.
- target (torch.Tensor): Corresponding gt convexes.
- weight (torch.Tensor, optional): The weight of loss for each
+ pred (jt.Tensor): Predicted convexes.
+ target (jt.Tensor): Corresponding gt convexes.
+ weight (jt.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
reduction (str, optional): The reduction method of the
loss. Defaults to None.
@@ -72,7 +72,7 @@ class ConvexGIoULoss(nn.Module):
loss_weight (float, optional): The weight of loss. Defaults to 1.0.
Return:
- torch.Tensor: Loss tensor.
+ jt.Tensor: Loss tensor.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
@@ -90,9 +90,9 @@ def execute(self,
"""Forward function.
Args:
- pred (torch.Tensor): Predicted convexes.
- target (torch.Tensor): Corresponding gt convexes.
- weight (torch.Tensor, optional): The weight of loss for each
+ pred (jt.Tensor): Predicted convexes.
+ target (jt.Tensor): Corresponding gt convexes.
+ weight (jt.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
@@ -109,3 +109,220 @@ def execute(self,
pred, target, weight, reduction, avg_factor, self.loss_weight)
return loss
+
+@LOSSES.register_module()
+class BCConvexGIoULossFuction(jt.Function):
+ """The function of BCConvex GIoU loss."""
+
+
+ def execute(self,
+ pred,
+ target,
+ weight=None,
+ reduction=None,
+ avg_factor=None,
+ loss_weight=1.0):
+ """Forward function.
+
+ Args:
+ ctx: {save_for_backward, convex_points_grad}
+ pred (jt.Tensor): Predicted convexes.
+ target (jt.Tensor): Corresponding gt convexes.
+ weight (jt.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ reduction (str, optional): The reduction method of the
+ loss. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ loss_weight (float, optional): The weight of loss. Defaults to 1.0.
+ """
+ convex_gious, grad = reppoints_convex_giou(pred, target)
+
+ pts_pred_all_dx = pred[:, 0::2]
+ pts_pred_all_dy = pred[:, 1::2]
+ pred_left_x_inds = pts_pred_all_dx.argmin(dim=1)[0].unsqueeze(0)
+ pred_right_x_inds = pts_pred_all_dx.argmax(dim=1)[0].unsqueeze(0)
+ pred_up_y_inds = pts_pred_all_dy.argmin(dim=1)[0].unsqueeze(0)
+ pred_bottom_y_inds = pts_pred_all_dy.argmax(dim=1)[0].unsqueeze(0)
+
+ pred_right_x = pts_pred_all_dx.gather(dim=1, index=pred_right_x_inds)
+ pred_right_y = pts_pred_all_dy.gather(dim=1, index=pred_right_x_inds)
+
+ pred_left_x = pts_pred_all_dx.gather(dim=1, index=pred_left_x_inds)
+ pred_left_y = pts_pred_all_dy.gather(dim=1, index=pred_left_x_inds)
+
+ pred_up_x = pts_pred_all_dx.gather(dim=1, index=pred_up_y_inds)
+ pred_up_y = pts_pred_all_dy.gather(dim=1, index=pred_up_y_inds)
+
+ pred_bottom_x = pts_pred_all_dx.gather(dim=1, index=pred_bottom_y_inds)
+ pred_bottom_y = pts_pred_all_dy.gather(dim=1, index=pred_bottom_y_inds)
+ pred_corners = jt.concat([
+ pred_left_x, pred_left_y, pred_up_x, pred_up_y, pred_right_x,
+ pred_right_y, pred_bottom_x, pred_bottom_y
+ ],
+ dim=-1)
+
+ pts_target_all_dx = target[:, 0::2]
+ pts_target_all_dy = target[:, 1::2]
+
+ target_left_x_inds = pts_target_all_dx.argmin(dim=1)[0].unsqueeze(0)
+ target_right_x_inds = pts_target_all_dx.argmax(dim=1)[0].unsqueeze(0)
+ target_up_y_inds = pts_target_all_dy.argmin(dim=1)[0].unsqueeze(0)
+ target_bottom_y_inds = pts_target_all_dy.argmax(dim=1)[0].unsqueeze(0)
+
+ target_right_x = pts_target_all_dx.gather(
+ dim=1, index=target_right_x_inds)
+ target_right_y = pts_target_all_dy.gather(
+ dim=1, index=target_right_x_inds)
+
+ target_left_x = pts_target_all_dx.gather(
+ dim=1, index=target_left_x_inds)
+ target_left_y = pts_target_all_dy.gather(
+ dim=1, index=target_left_x_inds)
+
+ target_up_x = pts_target_all_dx.gather(dim=1, index=target_up_y_inds)
+ target_up_y = pts_target_all_dy.gather(dim=1, index=target_up_y_inds)
+
+ target_bottom_x = pts_target_all_dx.gather(
+ dim=1, index=target_bottom_y_inds)
+ target_bottom_y = pts_target_all_dy.gather(
+ dim=1, index=target_bottom_y_inds)
+
+ target_corners = jt.concat([
+ target_left_x, target_left_y, target_up_x, target_up_y,
+ target_right_x, target_right_y, target_bottom_x, target_bottom_y
+ ],
+ dim=-1)
+
+ pts_pred_dx_mean = pts_pred_all_dx.mean(
+ dim=1).reshape(-1, 1)
+ pts_pred_dy_mean = pts_pred_all_dy.mean(
+ dim=1).reshape(-1, 1)
+ pts_pred_mean = jt.concat([pts_pred_dx_mean, pts_pred_dy_mean], dim=-1)
+
+ pts_target_dx_mean = pts_target_all_dx.mean(
+ dim=1).reshape(-1, 1)
+ pts_target_dy_mean = pts_target_all_dy.mean(
+ dim=1).reshape(-1, 1)
+ pts_target_mean = jt.concat([pts_target_dx_mean, pts_target_dy_mean],
+ dim=-1)
+
+ beta = 1.0
+
+ diff_mean = jt.abs(pts_pred_mean - pts_target_mean)
+ diff_mean_loss = jt.where(diff_mean < beta,
+ 0.5 * diff_mean * diff_mean / beta,
+ diff_mean - 0.5 * beta)
+ diff_mean_loss = diff_mean_loss.sum() / len(diff_mean_loss)
+
+ diff_corners = jt.abs(pred_corners - target_corners)
+ diff_corners_loss = jt.where(
+ diff_corners < beta, 0.5 * diff_corners * diff_corners / beta,
+ diff_corners - 0.5 * beta)
+ diff_corners_loss = diff_corners_loss.sum() / len(diff_corners_loss)
+
+ target_aspect = AspectRatio(target)
+ smooth_loss_weight = jt.exp((-1 / 4) * target_aspect)
+ loss = \
+ smooth_loss_weight * (diff_mean_loss.reshape(-1, 1) +
+ diff_corners_loss.reshape(-1, 1)) + \
+ 1 - (1 - 2 * smooth_loss_weight) * convex_gious
+
+ if weight is not None:
+ loss = loss * weight
+ grad = grad * weight.reshape(-1, 1)
+ if reduction == 'sum':
+ loss = loss.sum()
+ elif reduction == 'mean':
+ loss = loss.mean()
+
+ unvaild_inds = jt.nonzero((grad > 1).sum(1))[:, 0]
+ grad[unvaild_inds] = 1e-6
+
+ reduce_grad = -grad / grad.size(0) * loss_weight
+ self.convex_points_grad = reduce_grad
+ return loss
+
+ def grad(self, input=None):
+ """Backward function."""
+ convex_points_grad = self.convex_points_grad
+ return convex_points_grad, None, None, None, None, None
+
+
+bc_convex_giou_loss = BCConvexGIoULossFuction.apply
+
+
+@LOSSES.register_module()
+class BCConvexGIoULoss(nn.Module):
+ """BCConvex GIoU loss.
+
+ Computing the BCConvex GIoU loss between a set of predicted convexes and
+ target convexes.
+
+ Args:
+ reduction (str, optional): The reduction method of the loss. Defaults
+ to 'mean'.
+ loss_weight (float, optional): The weight of loss. Defaults to 1.0.
+
+ Return:
+ jt.Tensor: Loss tensor.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(BCConvexGIoULoss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def execute(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (jt.Tensor): Predicted convexes.
+ target (jt.Tensor): Corresponding gt convexes.
+ weight (jt.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ if weight is not None and not jt.any(weight > 0):
+ return (pred * weight.unsqueeze(-1)).sum()
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss = self.loss_weight * bc_convex_giou_loss(
+ pred, target, weight, reduction, avg_factor, self.loss_weight)
+ return loss
+
+
+def AspectRatio(gt_rbboxes):
+ """Compute the aspect ratio of all gts.
+
+ Args:
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8).
+
+ Returns:
+ ratios (jt.Tensor): The aspect ratio of gt_rbboxes, shape (k, 1).
+ """
+ pt1, pt2, pt3, pt4 = gt_rbboxes[..., :8].chunk(4, 1)
+ edge1 = jt.sqrt(
+ jt.pow(pt1[..., 0] - pt2[..., 0], 2) +
+ jt.pow(pt1[..., 1] - pt2[..., 1], 2))
+ edge2 = jt.sqrt(
+ jt.pow(pt2[..., 0] - pt3[..., 0], 2) +
+ jt.pow(pt2[..., 1] - pt3[..., 1], 2))
+
+ edges = jt.stack([edge1, edge2], dim=1)
+
+ width= jt.max(edges, 1)
+ height = jt.min(edges, 1)
+ ratios = (width / height)
+ return ratios
diff --git a/python/jdet/models/losses/spatial_border_loss.py b/python/jdet/models/losses/spatial_border_loss.py
new file mode 100644
index 0000000..863120b
--- /dev/null
+++ b/python/jdet/models/losses/spatial_border_loss.py
@@ -0,0 +1,99 @@
+import jittor as jt
+from jittor import nn
+from jdet.utils.registry import LOSSES
+
+from jdet.models.boxes.box_ops import points_in_rotated_boxes
+from jdet.ops.bbox_transforms import poly2obb
+import numpy as np
+
+@LOSSES.register_module()
+class SpatialBorderLoss(nn.Module):
+ """Spatial Border loss for learning points in Oriented RepPoints.
+ Args:
+ pts (jt.Tensor): point sets with shape (N, 9*2).
+ Default points number in each point set is 9.
+ gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8)
+ Returns:
+ loss (jt.Tensor)
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(SpatialBorderLoss, self).__init__()
+ self.loss_weight = loss_weight
+
+ def execute(self, pts, gt_bboxes, weight, *args, **kwargs):
+ loss = self.loss_weight * weighted_spatial_border_loss(
+ pts, gt_bboxes, weight, *args, **kwargs)
+ return loss
+
+def to_le135(boxes):
+ # swap edge and angle if h >= w
+ x, y, w, h, t = boxes.unbind(dim=-1)
+ start_angle = -45 / 180 * np.pi
+ w_ = jt.where(w > h, w, h)
+ h_ = jt.where(w > h, h, w)
+ t = jt.where(w > h, t, t + np.pi / 2)
+ t = ((t - start_angle) % np.pi) + start_angle
+ return jt.stack([x, y, w_, h_, t], dim=-1)
+
+def spatial_border_loss(pts, gt_bboxes):
+ """The loss is used to penalize the learning points out of the assigned
+ ground truth boxes (polygon by default).
+ Args:
+ pts (jt.Tensor): point sets with shape (N, 9*2).
+ gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8)
+ Returns:
+ loss (jt.Tensor)
+ """
+
+ num_gts, num_pointsets = gt_bboxes.size(0), pts.size(0)
+ num_point = int(pts.size(1) / 2.0)
+ loss = jt.zeros([0])
+ gt_bboxes_ = to_le135(poly2obb(gt_bboxes))
+
+ if num_gts > 0:
+ inside_flag_list = []
+ for i in range(num_point):
+ pt = pts[:, (2 * i):(2 * i + 2)].reshape(num_pointsets,
+ 2).contiguous()
+ # inside_pt_flag = points_in_polygons(pt, gt_bboxes)
+ inside_pt_flag = points_in_rotated_boxes(pt, gt_bboxes_)
+ inside_pt_flag = jt.diag(inside_pt_flag)
+ inside_flag_list.append(inside_pt_flag)
+
+ inside_flag = jt.stack(inside_flag_list, dim=1)
+ pts = pts.reshape(-1, num_point, 2)
+ out_border_pts = pts[jt.where(inside_flag == 0)]
+
+ if out_border_pts.size(0) > 0:
+ corr_gt_boxes = gt_bboxes[jt.where(inside_flag == 0)[0]]
+ corr_gt_boxes_center_x = (corr_gt_boxes[:, 0] +
+ corr_gt_boxes[:, 4]) / 2.0
+ corr_gt_boxes_center_y = (corr_gt_boxes[:, 1] +
+ corr_gt_boxes[:, 5]) / 2.0
+ corr_gt_boxes_center = jt.stack(
+ [corr_gt_boxes_center_x, corr_gt_boxes_center_y], dim=1)
+ distance_out_pts = 0.2 * ((
+ (out_border_pts - corr_gt_boxes_center)**2).sum(dim=1).sqrt())
+ loss = distance_out_pts.sum() / out_border_pts.size(0)
+
+ return loss
+
+
+def weighted_spatial_border_loss(pts, gt_bboxes, weight, avg_factor=None):
+    """Weighted spatial border loss.
+ Args:
+ pts (jt.Tensor): point sets with shape (N, 9*2).
+ gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8)
+ weight (jt.Tensor): weights for point sets with shape (N)
+ Returns:
+ loss (jt.Tensor)
+ """
+
+ weight = weight.unsqueeze(dim=1).repeat(1, 4)
+ assert len(weight.shape) == 2
+ if avg_factor is None:
+ avg_factor = jt.sum(weight > 0).float().item() / 4 + 1e-6
+ loss = spatial_border_loss(pts, gt_bboxes)
+
+ return jt.sum(loss)[None] / avg_factor
diff --git a/python/jdet/models/roi_heads/__init__.py b/python/jdet/models/roi_heads/__init__.py
index c82b747..19e79df 100644
--- a/python/jdet/models/roi_heads/__init__.py
+++ b/python/jdet/models/roi_heads/__init__.py
@@ -19,4 +19,6 @@
from . import rsdet_head
from . import rotated_atss_head
from . import rotated_reppoints_head
+from . import oriented_reppoints_head
+from . import sam_reppoints_head
__all__ = []
diff --git a/python/jdet/models/roi_heads/oriented_reppoints_head.py b/python/jdet/models/roi_heads/oriented_reppoints_head.py
new file mode 100644
index 0000000..628ba8f
--- /dev/null
+++ b/python/jdet/models/roi_heads/oriented_reppoints_head.py
@@ -0,0 +1,1527 @@
+from jdet.utils.registry import HEADS, LOSSES, BOXES, build_from_cfg
+from jdet.models.utils.modules import ConvModule
+from jdet.ops.dcn_v1 import DeformConv
+from jdet.ops.bbox_transforms import obb2poly, poly2obb
+from jdet.utils.general import multi_apply, unmap
+from jdet.ops.nms_rotated import multiclass_nms_rotated
+from jdet.models.boxes.anchor_target import images_to_levels
+from jdet.ops.reppoints_convex_iou import reppoints_convex_iou
+from jdet.ops.reppoints_min_area_bbox import reppoints_min_area_bbox
+from jdet.ops.chamfer_distance import chamfer_distance
+from jdet.models.boxes.box_ops import rotated_box_to_poly
+
+import math
+import jittor as jt
+from jittor import nn
+import numpy as np
+from jittor.nn import _pair, grid_sample
+
+import numpy as np
+def deleteme(a, b, size = 10):
+ if a is None and b is None:
+ return
+ if isinstance(a, dict) and isinstance(b, dict):
+ print('-' * size)
+ for a1, b1 in zip(a.values(), b.values()):
+ deleteme(a1, b1, size + 10)
+ print('-' * size)
+ elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
+ print('-' * size)
+ for a1, b1 in zip(a, b):
+ deleteme(a1, b1, size + 10)
+ print('-' * size)
+ elif isinstance(a, jt.Var) and isinstance(b, np.ndarray):
+ print((a - b).abs().max().item())
+ elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+ print(np.max(np.abs(a - b)))
+ elif isinstance(a, (int, float)) and isinstance(b, (int, float)):
+ print("number diff:", a - b)
+ else:
+ print(type(a))
+ print(type(b))
+ raise NotImplementedError
+def transpose_to(a, b):
+ if a is None:
+ return None
+ if isinstance(a, list) and isinstance(b, list):
+ rlist = []
+ for a1, b1 in zip(a, b):
+ rlist.append(transpose_to(a1, b1))
+ return rlist
+ elif isinstance(a, dict) and isinstance(b, dict):
+ rdict = []
+ for k in b.keys():
+ rdict[k] = transpose_to(a[k], b[k])
+ return rdict
+ elif isinstance(a, np.ndarray) and isinstance(b, jt.Var):
+ return jt.array(a)
+ elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+ return a
+ elif isinstance(a, tuple) and isinstance(b, tuple):
+ rlist = [transpose_to(a1, b1) for a1, b1 in zip(a, b)]
+ return tuple(rlist)
+ elif isinstance(a, (int, float, str)) and isinstance(b, (int, float, str)):
+ assert(type(a) == type(b))
+ return a
+ else:
+ print(type(a))
+ print(type(b))
+ raise NotImplementedError
+
+def fake_argsort2(x, dim=0, descending=False):
+ x_ = x.data
+ if (descending):
+ x__ = -x_
+ else:
+ x__ = x_
+ index_ = np.argsort(x__, axis=dim, kind="stable")
+ y_ = x_[index_]
+ index = jt.array(index_)
+ y = jt.array(y_)
+ return index, y
+
+def ChamferDistance2D(point_set_1,
+ point_set_2,
+ distance_weight=0.05,
+ eps=1e-12):
+ """Compute the Chamfer distance between two point sets.
+
+ Args:
+ point_set_1 (jt.tensor): point set 1 with shape (N_pointsets,
+ N_points, 2)
+ point_set_2 (jt.tensor): point set 2 with shape (N_pointsets,
+ N_points, 2)
+
+ Returns:
+ dist (jt.tensor): chamfer distance between two point sets
+ with shape (N_pointsets,)
+ """
+ assert point_set_1.shape == point_set_2.shape
+ assert point_set_1.shape[-1] == point_set_2.shape[-1]
+ # assert point_set_1.dim() <= 3
+ dist1, dist2, _, _ = chamfer_distance(point_set_1, point_set_2)
+ dist1 = jt.sqrt(jt.clamp(dist1, eps))
+ dist2 = jt.sqrt(jt.clamp(dist2, eps))
+ dist = distance_weight * (dist1.mean(-1) + dist2.mean(-1)) / 2.0
+
+ return dist
+
+
+@HEADS.register_module()
+class OrientedRepPointsHead(nn.Module):
+ """Rotated RepPoints head.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of input channels.
+ feat_channels (int): Number of feature channels.
+ point_feat_channels (int, optional): Number of channels of points
+ features.
+ stacked_convs (int, optional): Number of stacked convolutions.
+ num_points (int, optional): Number of points in points set.
+ gradient_mul (float, optional): The multiplier to gradients from
+ points refinement and recognition.
+ point_strides (Iterable, optional): points strides.
+ point_base_scale (int, optional): Bbox scale for assigning labels.
+ conv_bias (str, optional): The bias of convolution.
+ loss_cls (dict, optional): Config of classification loss.
+ loss_bbox_init (dict, optional): Config of initial points loss.
+ loss_bbox_refine (dict, optional): Config of points loss in refinement.
+ conv_cfg (dict, optional): The config of convolution.
+        norm_cfg (dict, optional): The config of normalization.
+ train_cfg (dict, optional): The config of train.
+ test_cfg (dict, optional): The config of test.
+ center_init (bool, optional): Whether to use center point assignment.
+ top_ratio (float, optional): Ratio of top high-quality point sets.
+ Defaults to 0.4.
+ init_qua_weight (float, optional): Quality weight of initial
+ stage.
+ ori_qua_weight (float, optional): Orientation quality weight.
+ poc_qua_weight (float, optional): Point-wise correlation
+ quality weight.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.1,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=4,
+ conv_bias='auto',
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
+ loss_bbox_refine=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ loss_spatial_init=dict(
+ type='SpatialBorderLoss', loss_weight=0.05),
+ loss_spatial_refine=dict(
+ type='SpatialBorderLoss', loss_weight=0.1),
+ conv_cfg=None,
+ norm_cfg=None,
+ train_cfg=None,
+ test_cfg=None,
+ center_init=True,
+ top_ratio=0.4,
+ init_qua_weight=0.2,
+ ori_qua_weight=0.3,
+ poc_qua_weight=0.1,
+ **kwargs):
+ self.num_points = num_points
+ self.point_feat_channels = point_feat_channels
+ self.center_init = center_init
+
+ # we use deform conv to extract points features
+ self.dcn_kernel = int(np.sqrt(num_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+ assert self.dcn_kernel * self.dcn_kernel == num_points, \
+ 'The points number should be a square number.'
+ assert self.dcn_kernel % 2 == 1, \
+ 'The points number should be an odd square number.'
+ dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
+ self.dcn_base_offset = jt.array(dcn_base_offset).view(1, -1, 1, 1)
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+ self.conv_bias = conv_bias
+ self.loss_cls = build_from_cfg(loss_cls, LOSSES)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.gradient_mul = gradient_mul
+ self.point_base_scale = point_base_scale
+ self.point_strides = point_strides
+ self.prior_generator = MlvlPointGenerator(self.point_strides, offset=0.)
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ if self.train_cfg:
+ self.init_assigner = build_from_cfg(self.train_cfg.init.assigner, BOXES)
+ self.refine_assigner = build_from_cfg(self.train_cfg.refine.assigner, BOXES)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_from_cfg(sampler_cfg, BOXES)
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+ self.loss_bbox_init = build_from_cfg(loss_bbox_init, LOSSES)
+ self.loss_bbox_refine = build_from_cfg(loss_bbox_refine, LOSSES)
+ self.loss_spatial_init = build_from_cfg(loss_spatial_init, LOSSES)
+ self.loss_spatial_refine = build_from_cfg(loss_spatial_refine, LOSSES)
+ self.init_qua_weight = init_qua_weight
+ self.ori_qua_weight = ori_qua_weight
+ self.poc_qua_weight = poc_qua_weight
+ self.top_ratio = top_ratio
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU()
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+ pts_out_dim = 2 * self.num_points
+ self.reppoints_cls_conv = DeformConv(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
+ self.cls_out_channels, 1, 1, 0)
+ self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
+ self.point_feat_channels, 3,
+ 1, 1)
+ self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+ self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+
+ def points2rotrect(self, pts, y_first=True):
+ """Convert points to oriented bboxes."""
+ if y_first:
+ pts = pts.reshape(-1, self.num_points, 2)
+ pts_dy = pts[:, :, 0::2]
+ pts_dx = pts[:, :, 1::2]
+ pts = jt.concat([pts_dx, pts_dy],
+ dim=2).reshape(-1, 2 * self.num_points)
+ return reppoints_min_area_bbox(pts)
+
+
+ def forward_single(self, x):
+ """Forward feature map of a single FPN level."""
+ dcn_base_offset = self.dcn_base_offset.type_as(x)
+ points_init = 0
+ cls_feat = x
+ pts_feat = x
+ base_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ pts_feat = reg_conv(pts_feat)
+ # initialize reppoints
+ pts_out_init = self.reppoints_pts_init_out(
+ self.relu(self.reppoints_pts_init_conv(pts_feat)))
+ pts_out_init = pts_out_init + points_init
+ # refine and classify reppoints
+ pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
+ dcn_offset = pts_out_init_grad_mul - dcn_base_offset
+ cls_out = self.reppoints_cls_out(
+ self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
+ pts_out_refine = self.reppoints_pts_refine_out(
+ self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
+ pts_out_refine = pts_out_refine + pts_out_init.detach()
+
+ return cls_out, pts_out_init, pts_out_refine, base_feat
+
+ def get_points(self, featmap_sizes, img_metas):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+
+ Returns:
+ tuple: points of each image, valid flags of each image
+ """
+ num_imgs = len(img_metas)
+
+ multi_level_points = self.prior_generator.grid_priors(
+ featmap_sizes, with_stride=True)
+ points_list = [[point.clone() for point in multi_level_points]
+ for _ in range(num_imgs)]
+
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = self.prior_generator.valid_flags(
+ featmap_sizes, img_meta['pad_shape'])
+ valid_flag_list.append(multi_level_flags)
+
+ return points_list, valid_flag_list
+
+ def offset_to_pts(self, center_list, pred_list):
+ """Change from point offset to point coordinate."""
+ pts_list = []
+ for i_lvl, _ in enumerate(self.point_strides):
+ pts_lvl = []
+ for i_img, _ in enumerate(center_list):
+ pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points)
+ pts_shift = pred_list[i_lvl][i_img]
+ yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points)
+ y_pts_shift = yx_pts_shift[..., 0::2]
+ x_pts_shift = yx_pts_shift[..., 1::2]
+ xy_pts_shift = jt.stack([x_pts_shift, y_pts_shift], -1)
+ xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
+ pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
+ pts_lvl.append(pts)
+ pts_lvl = jt.stack(pts_lvl, 0)
+ pts_list.append(pts_lvl)
+ return pts_list
+
+ def get_adaptive_points_feature(self, features, pt_locations, stride):
+ """Get the points features from the locations of predicted points.
+
+ Args:
+ features (jt.tensor): base feature with shape (B,C,W,H)
+ pt_locations (jt.tensor): locations of points in each point set
+ with shape (B, N_points_set(number of point set),
+ N_points(number of points in each point set) *2)
+ Returns:
+ tensor: sampling features with (B, C, N_points_set, N_points)
+ """
+
+ h = features.shape[2] * stride
+ w = features.shape[3] * stride
+
+ pt_locations = pt_locations.view(pt_locations.shape[0],
+ pt_locations.shape[1], -1, 2).clone()
+ pt_locations[..., 0] = pt_locations[..., 0] / (w / 2.) - 1
+ pt_locations[..., 1] = pt_locations[..., 1] / (h / 2.) - 1
+
+ batch_size = features.size(0)
+ sampled_features = jt.zeros([
+ pt_locations.shape[0],
+ features.size(1),
+ pt_locations.size(1),
+ pt_locations.size(2)
+ ])
+
+ for i in range(batch_size):
+ feature = grid_sample(features[i:i + 1],
+ pt_locations[i:i + 1])[0]
+ sampled_features[i] = feature
+
+ return sampled_features,
+
+ def feature_cosine_similarity(self, points_features):
+ """Compute the points features similarity for points-wise correlation.
+
+ Args:
+ points_features (jt.tensor): sampling point feature with
+ shape (N_pointsets, N_points, C)
+ Returns:
+ max_correlation: max feature similarity in each point set with
+ shape (N_points_set, N_points, C)
+ """
+
+ mean_points_feats = jt.mean(points_features, dim=1, keepdims=True)
+ norm_pts_feats = jt.norm(
+ points_features, p=2, dim=2).unsqueeze(dim=2).clamp(min_v=1e-2)
+ norm_mean_pts_feats = jt.norm(
+ mean_points_feats, p=2, dim=2).unsqueeze(dim=2).clamp(min_v=1e-2)
+
+ unity_points_features = points_features / norm_pts_feats
+ unity_mean_points_feats = mean_points_feats / norm_mean_pts_feats
+
+ # cos_similarity = nn.CosineSimilarity(dim=2, eps=1e-6)
+ unity_points_features = 1. * unity_points_features / (jt.norm(unity_points_features, 2, 2, keepdims=True).expand_as(unity_mean_points_feats) + 1e-6)
+ unity_mean_points_feats = 1. * unity_mean_points_feats / (jt.norm(unity_mean_points_feats, 2, 2, keepdims=True).expand_as(unity_mean_points_feats) + 1e-6)
+ feats_similarity = 1.0 - (unity_points_features * unity_mean_points_feats).sum(dim=-1)
+ # feats_similarity = 1.0 - cos_similarity(unity_points_features,
+ # unity_mean_points_feats)
+
+ max_correlation = jt.max(feats_similarity, dim=1)
+
+ return max_correlation
+
+ def sampling_points(self, polygons, points_num):
+ """Sample edge points for polygon.
+
+ Args:
+ polygons (jt.tensor): polygons with shape (N, 8)
+ points_num (int): number of sampling points for each polygon edge.
+ 10 by default.
+
+ Returns:
+ sampling_points (jt.tensor): sampling points with shape (N,
+ points_num*4, 2)
+ """
+ polygons_xs, polygons_ys = polygons[:, 0::2], polygons[:, 1::2]
+ ratio = jt.linspace(0, 1, points_num).repeat(
+ polygons.shape[0], 1)
+
+ edge_pts_x = []
+ edge_pts_y = []
+ for i in range(4):
+ if i < 3:
+ points_x = ratio * polygons_xs[:, i + 1:i + 2] + (
+ 1 - ratio) * polygons_xs[:, i:i + 1]
+ points_y = ratio * polygons_ys[:, i + 1:i + 2] + (
+ 1 - ratio) * polygons_ys[:, i:i + 1]
+ else:
+ points_x = ratio * polygons_xs[:, 0].unsqueeze(1) + (
+ 1 - ratio) * polygons_xs[:, i].unsqueeze(1)
+ points_y = ratio * polygons_ys[:, 0].unsqueeze(1) + (
+ 1 - ratio) * polygons_ys[:, i].unsqueeze(1)
+
+ edge_pts_x.append(points_x)
+ edge_pts_y.append(points_y)
+
+ sampling_points_x = jt.concat(edge_pts_x, dim=1).unsqueeze(dim=2)
+ sampling_points_y = jt.concat(edge_pts_y, dim=1).unsqueeze(dim=2)
+ sampling_points = jt.concat([sampling_points_x, sampling_points_y],
+ dim=2)
+
+ return sampling_points
+
+ def pointsets_quality_assessment(self, pts_features, cls_score,
+ pts_pred_init, pts_pred_refine, label,
+ bbox_gt, label_weight, bbox_weight,
+ pos_inds):
+ """Assess the quality of each point set from the classification,
+ localization, orientation, and point-wise correlation based on
+ the assigned point sets samples.
+ Args:
+ pts_features (jt.tensor): points features with shape (N, 9, C)
+ cls_score (jt.tensor): classification scores with
+ shape (N, class_num)
+ pts_pred_init (jt.tensor): initial point sets prediction with
+ shape (N, 9*2)
+ pts_pred_refine (jt.tensor): refined point sets prediction with
+ shape (N, 9*2)
+ label (jt.tensor): gt label with shape (N)
+ bbox_gt(jt.tensor): gt bbox of polygon with shape (N, 8)
+ label_weight (jt.tensor): label weight with shape (N)
+ bbox_weight (jt.tensor): box weight with shape (N)
+ pos_inds (jt.tensor): the inds of positive point set samples
+
+ Returns:
+ qua (jt.tensor) : weighted quality values for positive
+ point set samples.
+ """
+ pos_scores = cls_score[pos_inds]
+ pos_pts_pred_init = pts_pred_init[pos_inds]
+ pos_pts_pred_refine = pts_pred_refine[pos_inds]
+ pos_pts_refine_features = pts_features[pos_inds]
+ pos_bbox_gt = bbox_gt[pos_inds]
+ pos_label = label[pos_inds]
+ pos_label_weight = label_weight[pos_inds]
+ pos_bbox_weight = bbox_weight[pos_inds]
+
+ # quality of point-wise correlation
+ qua_poc = self.poc_qua_weight * self.feature_cosine_similarity(
+ pos_pts_refine_features)
+
+ qua_cls = self.loss_cls(
+ pos_scores,
+ pos_label,
+ pos_label_weight,
+ avg_factor=self.loss_cls.loss_weight,
+ reduction_override='none')
+
+ polygons_pred_init = reppoints_min_area_bbox(pos_pts_pred_init)
+ polygons_pred_refine = reppoints_min_area_bbox(pos_pts_pred_refine)
+ sampling_pts_pred_init = self.sampling_points(
+ polygons_pred_init, 10)
+ sampling_pts_pred_refine = self.sampling_points(
+ polygons_pred_refine, 10)
+ sampling_pts_gt = self.sampling_points(pos_bbox_gt, 10)
+
+ # quality of orientation
+ qua_ori_init = self.ori_qua_weight * ChamferDistance2D(
+ sampling_pts_gt, sampling_pts_pred_init)
+ qua_ori_refine = self.ori_qua_weight * ChamferDistance2D(
+ sampling_pts_gt, sampling_pts_pred_refine)
+
+ # quality of localization
+ qua_loc_init = self.loss_bbox_refine(
+ pos_pts_pred_init,
+ pos_bbox_gt,
+ pos_bbox_weight,
+ avg_factor=self.loss_cls.loss_weight,
+ reduction_override='none')
+ qua_loc_refine = self.loss_bbox_refine(
+ pos_pts_pred_refine,
+ pos_bbox_gt,
+ pos_bbox_weight,
+ avg_factor=self.loss_cls.loss_weight,
+ reduction_override='none')
+
+ # quality of classification
+ qua_cls = qua_cls.sum(-1)
+
+        # weighted init-stage and refine-stage
+ qua = qua_cls + self.init_qua_weight * (
+ qua_loc_init + qua_ori_init) + (1.0 - self.init_qua_weight) * (
+ qua_loc_refine + qua_ori_refine) + qua_poc
+
+ return qua,
+
+ def dynamic_pointset_samples_selection(self,
+ quality,
+ label,
+ label_weight,
+ bbox_weight,
+ pos_inds,
+ pos_gt_inds,
+ num_proposals_each_level=None,
+ num_level=None):
+ """The dynamic top k selection of point set samples based on the
+ quality assessment values.
+
+ Args:
+ quality (jt.tensor): the quality values of positive
+ point set samples
+ label (jt.tensor): gt label with shape (N)
+ bbox_gt(jt.tensor): gt bbox of polygon with shape (N, 8)
+ label_weight (jt.tensor): label weight with shape (N)
+ bbox_weight (jt.tensor): box weight with shape (N)
+ pos_inds (jt.tensor): the inds of positive point set samples
+ num_proposals_each_level (list[int]): proposals number of
+ each level
+ num_level (int): the level number
+ Returns:
+ label: gt label with shape (N)
+ label_weight: label weight with shape (N)
+ bbox_weight: box weight with shape (N)
+ num_pos (int): the number of selected positive point samples
+ with high-qualty
+ pos_normalize_term (jt.tensor): the corresponding positive
+ normalize term
+ """
+
+ if len(pos_inds) == 0:
+ return label, label_weight, bbox_weight, 0, jt.array(
+ []).type_as(bbox_weight)
+
+ num_gt = pos_gt_inds.max()
+ num_proposals_each_level_ = num_proposals_each_level.copy()
+ num_proposals_each_level_.insert(0, 0)
+ inds_level_interval = np.cumsum(num_proposals_each_level_)
+ pos_level_mask = []
+ for i in range(num_level):
+ mask = (pos_inds >= inds_level_interval[i]) & (
+ pos_inds < inds_level_interval[i + 1])
+ pos_level_mask.append(mask)
+
+ pos_inds_after_select = []
+ ignore_inds_after_select = []
+
+ if int(num_gt) == 0:
+ return label, label_weight, bbox_weight, 0, jt.array(
+ []).type_as(bbox_weight)
+
+ for gt_ind in range(int(num_gt)):
+ pos_inds_select = []
+ pos_loss_select = []
+ gt_mask = pos_gt_inds == (gt_ind + 1)
+ for level in range(num_level):
+ level_mask = pos_level_mask[level]
+ level_gt_mask = level_mask & gt_mask
+ mask_quality = quality[level_gt_mask]
+ if level_gt_mask.sum() <= 6:
+ topk_inds, value = jt.argsort(mask_quality,dim=mask_quality.ndim-1,descending=False)
+ else:
+ value, topk_inds = mask_quality.topk(
+ 6, largest=False)
+ pos_inds_select.append(pos_inds[level_gt_mask][topk_inds])
+ pos_loss_select.append(value)
+ pos_inds_select = jt.concat(pos_inds_select)
+ pos_loss_select = jt.concat(pos_loss_select)
+
+ if len(pos_inds_select) < 2:
+ pos_inds_after_select.append(pos_inds_select)
+ ignore_inds_after_select.append(jt.empty([]))
+
+ else:
+ # pos_loss_select, sort_inds = pos_loss_select.sort() # small to large
+ # pos_loss_select, sort_inds = fake_argsort2(pos_loss_select)
+ sort_inds, pos_loss_select = fake_argsort2(pos_loss_select)
+ pos_inds_select = pos_inds_select[sort_inds]
+ # dynamic top k
+ topk = math.ceil(pos_loss_select.shape[0] * self.top_ratio)
+ pos_inds_select_topk = pos_inds_select[:topk]
+ pos_inds_after_select.append(pos_inds_select_topk)
+ ignore_inds_after_select.append(
+ jt.empty([]))
+
+ pos_inds_after_select = jt.concat(pos_inds_after_select)
+ ignore_inds_after_select = jt.concat(ignore_inds_after_select)
+
+ reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_select).all(1)
+ reassign_ids = pos_inds[reassign_mask]
+ label[reassign_ids] = 0
+ label_weight[ignore_inds_after_select] = 0
+ bbox_weight[reassign_ids] = 0
+ num_pos = len(pos_inds_after_select)
+
+ pos_level_mask_after_select = []
+ for i in range(num_level):
+ mask = (pos_inds_after_select >= inds_level_interval[i]) & (
+ pos_inds_after_select < inds_level_interval[i + 1])
+ pos_level_mask_after_select.append(mask)
+ pos_level_mask_after_select = jt.stack(pos_level_mask_after_select,
+ 0).type_as(label)
+ pos_normalize_term = pos_level_mask_after_select * (
+ self.point_base_scale *
+ jt.array(self.point_strides).type_as(label)).reshape(-1, 1)
+ pos_normalize_term = pos_normalize_term[
+ pos_normalize_term > 0].type_as(bbox_weight)
+ assert len(pos_normalize_term) == len(pos_inds_after_select)
+
+ return label, label_weight, bbox_weight, num_pos, pos_normalize_term
+
    def _point_target_single(self,
                             flat_proposals,
                             valid_flags,
                             gt_bboxes,
                             gt_bboxes_ignore,
                             gt_labels,
                             overlaps,
                             stage='init',
                             unmap_outputs=True):
        """Compute point targets for a single image.

        Args:
            flat_proposals (jt.Var): all-level proposals of one image,
                already concatenated.
            valid_flags (jt.Var): bool mask of proposals that are valid.
            gt_bboxes (jt.Var): gt boxes; converted with ``obb2poly`` below
                before assignment.
            gt_bboxes_ignore (jt.Var): gt boxes to be ignored by the
                assigner.
            gt_labels (jt.Var | None): gt labels; if None, positives get
                label 1 (class-agnostic).
            overlaps: unused placeholder kept for the multi_apply signature.
            stage (str): 'init' or 'refine'; selects assigner / pos_weight.
            unmap_outputs (bool): whether to map results back to the full
                proposal set via ``unmap``.

        Returns:
            tuple: (labels, label_weights, bbox_gt, pos_proposals,
            proposals_weights, pos_inds, neg_inds, sampling_result), or
            eight ``None`` values when no proposal is valid.
        """
        inside_flags = valid_flags
        # No valid proposal in this image: signal the caller with Nones.
        if not inside_flags.any():
            return (None, ) * 8
        # assign gt and sample proposals
        proposals = flat_proposals[inside_flags, :]

        if stage == 'init':
            assigner = self.init_assigner
            pos_weight = self.train_cfg.init.pos_weight
        else:
            assigner = self.refine_assigner
            pos_weight = self.train_cfg.refine.pos_weight

        # convert gt from obb to poly
        gt_bboxes = obb2poly(gt_bboxes)

        assign_result = assigner.assign(proposals, gt_bboxes,
                                        gt_bboxes_ignore,
                                        None if self.sampling else gt_labels)
        sampling_result = self.sampler.sample(assign_result, proposals,
                                              gt_bboxes)


        num_valid_proposals = proposals.shape[0]
        # Targets are first built over the valid proposals only; polygon gt
        # has 8 coordinates (4 corner points).
        bbox_gt = jt.zeros([num_valid_proposals, 8], dtype=proposals.dtype)
        pos_proposals = jt.zeros_like(proposals)
        proposals_weights = jt.zeros(num_valid_proposals, dtype=proposals.dtype)
        # 0 acts as the background label (callers test `labels > 0`).
        labels = jt.full((num_valid_proposals, ),
                         0,
                         dtype=jt.int32)
        label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            pos_gt_bboxes = sampling_result.pos_gt_bboxes
            bbox_gt[pos_inds, :] = pos_gt_bboxes
            pos_proposals[pos_inds, :] = proposals[pos_inds, :]
            proposals_weights[pos_inds] = 1.0
            if gt_labels is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            # pos_weight <= 0 means "use the default weight of 1".
            if pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of proposals
        if unmap_outputs:
            num_total_proposals = flat_proposals.size(0)
            labels = unmap(labels, num_total_proposals, inside_flags)
            label_weights = unmap(label_weights, num_total_proposals,
                                  inside_flags)
            bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
            pos_proposals = unmap(pos_proposals, num_total_proposals,
                                  inside_flags)
            proposals_weights = unmap(proposals_weights, num_total_proposals,
                                      inside_flags)

        return (labels, label_weights, bbox_gt, pos_proposals,
                proposals_weights, pos_inds, neg_inds, sampling_result)
+
+ def init_loss_single(self, pts_pred_init, bbox_gt_init, bbox_weights_init,
+ stride):
+ """Single initial stage loss function."""
+ normalize_term = self.point_base_scale * stride
+
+ bbox_gt_init = bbox_gt_init.reshape(-1, 8)
+ bbox_weights_init = bbox_weights_init.reshape(-1)
+ pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points)
+ pos_ind_init = (bbox_weights_init > 0).nonzero().reshape(-1)
+
+ pts_pred_init_norm = pts_pred_init[pos_ind_init]
+ bbox_gt_init_norm = bbox_gt_init[pos_ind_init]
+ bbox_weights_pos_init = bbox_weights_init[pos_ind_init]
+
+ loss_pts_init = self.loss_bbox_init(
+ pts_pred_init_norm / normalize_term,
+ bbox_gt_init_norm / normalize_term, bbox_weights_pos_init)
+
+ loss_border_init = self.loss_spatial_init(
+ pts_pred_init_norm.reshape(-1, 2 * self.num_points) /
+ normalize_term,
+ bbox_gt_init_norm / normalize_term,
+ bbox_weights_pos_init,
+ avg_factor=None)
+
+ return loss_pts_init, loss_border_init
+
    def get_targets(self,
                    proposals_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    stage='init',
                    label_channels=1,
                    unmap_outputs=True):
        """Compute corresponding GT box and classification targets for
        proposals.

        Args:
            proposals_list (list[list]): Multi level points/bboxes of each
                image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            stage (str): `init` or `refine`. Generate target for init stage or
                refine stage
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple (list[Tensor]):

            For ``stage == 'init'`` (per-level targets):

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each \
                level.
            - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
            - proposal_list (list[Tensor]): Proposals(points/bboxes) of \
                each level.
            - proposal_weights_list (list[Tensor]): Proposal weights of \
                each level.
            - num_total_pos (int): Number of positive samples in all \
                images.
            - num_total_neg (int): Number of negative samples in all \
                images.

            For ``stage == 'refine'`` the per-image (not per-level) targets
            are returned, together with the positive indices and the
            assigned gt indices of each image.
        """
        assert stage in ['init', 'refine']
        num_imgs = len(img_metas)
        assert len(proposals_list) == len(valid_flag_list) == num_imgs

        # points number of multi levels
        num_level_proposals = [points.size(0) for points in proposals_list[0]]

        # concat all level points and flags to a single tensor
        for i in range(num_imgs):
            assert len(proposals_list[i]) == len(valid_flag_list[i])
            proposals_list[i] = jt.concat(proposals_list[i])
            valid_flag_list[i] = jt.concat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        # placeholder, unused by _point_target_single
        all_overlaps_rotate_list = [None] * len(proposals_list)
        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
         all_proposal_weights, pos_inds_list, neg_inds_list,
         sampling_result) = multi_apply(
             self._point_target_single,
             proposals_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             all_overlaps_rotate_list,
             stage=stage,
             unmap_outputs=unmap_outputs)

        if stage == 'init':
            # no valid points
            if any([labels is None for labels in all_labels]):
                return None
            # sampled points of all images
            num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
            num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
            # regroup the per-image targets into per-level tensors
            labels_list = images_to_levels(all_labels, num_level_proposals)
            label_weights_list = images_to_levels(all_label_weights,
                                                  num_level_proposals)
            bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
            proposals_list = images_to_levels(all_proposals, num_level_proposals)
            proposal_weights_list = images_to_levels(all_proposal_weights,
                                                     num_level_proposals)

            return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
                    proposal_weights_list, num_total_pos, num_total_neg, None)
        else:
            # refine stage keeps targets per image so the quality
            # assessment / sample selection can run image-wise
            pos_inds = []
            # pos_gt_index = []
            for i, single_labels in enumerate(all_labels):
                # labels > 0 marks positive samples (0 is background)
                pos_mask = single_labels > 0
                pos_inds.append(pos_mask.nonzero().view(-1))

            gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]

            return (all_labels, all_label_weights, all_bbox_gt, all_proposals,
                    all_proposal_weights, pos_inds, gt_inds)
+
    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             base_features,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function of CFA head.

        Pipeline: build init-stage targets, build refine-stage targets,
        assess the quality of each positive point set, dynamically select
        the top-quality positives, then compute the classification,
        init-stage and refine-stage regression/spatial losses.

        Args:
            cls_scores (list[jt.Var]): per-level classification maps.
            pts_preds_init (list[jt.Var]): per-level init-stage point
                offsets.
            pts_preds_refine (list[jt.Var]): per-level refine-stage point
                offsets.
            base_features (list[jt.Var]): per-level backbone/FPN features
                used for point-wise feature sampling.
            gt_bboxes (list[jt.Var]): gt rotated boxes of each image.
            gt_labels (list[jt.Var]): gt labels of each image.
            img_metas (list[dict]): meta info of each image.
            gt_bboxes_ignore (list[jt.Var], optional): gt boxes to ignore.

        Returns:
            dict: losses keyed by loss name.
        """

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1


        # target for initial stage
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)

        num_proposals_each_level = [(featmap.size(-1) * featmap.size(-2))
                                    for featmap in cls_scores]
        num_level = len(featmap_sizes)
        assert num_level == len(pts_coordinate_preds_init)

        if self.train_cfg.init.assigner['type'] == 'ConvexAssigner':
            candidate_list = center_list
        else:
            raise NotImplementedError
        cls_reg_targets_init = self.get_targets(
            candidate_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='init',
            label_channels=label_channels)
        (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
         num_total_pos_init, num_total_neg_init, _) = cls_reg_targets_init

        # target for refinement stage
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)

        # sample per-point features used for the point-wise correlation
        # quality term
        refine_points_features, = multi_apply(self.get_adaptive_points_feature,
                                              base_features,
                                              pts_coordinate_preds_refine,
                                              self.point_strides)
        features_pts_refine = levels_to_images(refine_points_features)
        features_pts_refine = [
            item.reshape(-1, self.num_points, item.shape[-1])
            for item in features_pts_refine
        ]

        # refine-stage candidate points: centers shifted by the (detached)
        # init-stage offsets, scaled by the level stride
        points_list = []
        for i_img, center in enumerate(center_list):
            points = []
            for i_lvl in range(len(pts_preds_refine)):
                points_preds_init_ = pts_preds_init[i_lvl].detach()
                points_preds_init_ = points_preds_init_.view(
                    points_preds_init_.shape[0], -1,
                    *points_preds_init_.shape[2:])
                points_shift = points_preds_init_.permute(
                    0, 2, 3, 1) * self.point_strides[i_lvl]
                points_center = center[i_lvl][:, :2].repeat(1, self.num_points)
                points.append(
                    points_center +
                    points_shift[i_img].reshape(-1, 2 * self.num_points))
            points_list.append(points)

        cls_reg_targets_refine = self.get_targets(
            points_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='refine',
            label_channels=label_channels)

        (labels_list, label_weights_list, bbox_gt_list_refine, _,
         bbox_weights_list_refine, pos_inds_list_refine,
         pos_gt_index_list_refine) = cls_reg_targets_refine

        # reshape predictions from per-level maps to per-image flat tensors
        cls_scores = levels_to_images(cls_scores)
        cls_scores = [
            item.reshape(-1, self.cls_out_channels) for item in cls_scores
        ]

        pts_coordinate_preds_init_img = levels_to_images(
            pts_coordinate_preds_init, flatten=True)
        pts_coordinate_preds_init_img = [
            item.reshape(-1, 2 * self.num_points)
            for item in pts_coordinate_preds_init_img
        ]

        pts_coordinate_preds_refine_img = levels_to_images(
            pts_coordinate_preds_refine, flatten=True)
        pts_coordinate_preds_refine_img = [
            item.reshape(-1, 2 * self.num_points)
            for item in pts_coordinate_preds_refine_img
        ]

        # quality assessment and dynamic sample selection are pure target
        # re-weighting, so no gradients are needed here
        with jt.no_grad():

            quality_assess_list, = multi_apply(
                self.pointsets_quality_assessment, features_pts_refine,
                cls_scores, pts_coordinate_preds_init_img,
                pts_coordinate_preds_refine_img, labels_list,
                bbox_gt_list_refine, label_weights_list,
                bbox_weights_list_refine, pos_inds_list_refine)

            labels_list, label_weights_list, bbox_weights_list_refine, \
                num_pos, pos_normalize_term = multi_apply(
                    self.dynamic_pointset_samples_selection,
                    quality_assess_list,
                    labels_list,
                    label_weights_list,
                    bbox_weights_list_refine,
                    pos_inds_list_refine,
                    pos_gt_index_list_refine,
                    num_proposals_each_level=num_proposals_each_level,
                    num_level=num_level
                )
            num_pos = sum(num_pos)

        # convert all tensor list to a flatten tensor
        cls_scores = jt.concat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
        pts_preds_refine = jt.concat(pts_coordinate_preds_refine_img, 0).view(
            -1, pts_coordinate_preds_refine_img[0].size(-1))

        labels = jt.concat(labels_list, 0).view(-1)
        labels_weight = jt.concat(label_weights_list, 0).view(-1)
        bbox_gt_refine = jt.concat(bbox_gt_list_refine,
                                   0).view(-1, bbox_gt_list_refine[0].size(-1))
        bbox_weights_refine = jt.concat(bbox_weights_list_refine, 0).view(-1)
        pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1)
        pos_inds_flatten = (labels > 0).nonzero().reshape(-1)

        # print('pos_normalize_term: ', pos_normalize_term.shape, pos_inds_flatten.shape)
        # assert len(pos_normalize_term) == len(pos_inds_flatten)

        # NOTE(review): the length-equality guard replaces the assert above;
        # when the two disagree the refine losses silently fall back to the
        # zero branch below — confirm this is intended.
        if bool(num_pos) and len(pos_normalize_term) == len(pos_inds_flatten):
            losses_cls = self.loss_cls(
                cls_scores, labels, labels_weight, avg_factor=num_pos)
            pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten]
            pos_bbox_gt_refine = bbox_gt_refine[pos_inds_flatten]

            pos_bbox_weights_refine = bbox_weights_refine[pos_inds_flatten]
            losses_pts_refine = self.loss_bbox_refine(
                pos_pts_pred_refine / pos_normalize_term.reshape(-1, 1),
                pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1),
                pos_bbox_weights_refine)

            loss_border_refine = self.loss_spatial_refine(
                pos_pts_pred_refine.reshape(-1, 2 * self.num_points) /
                pos_normalize_term.reshape(-1, 1),
                pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1),
                pos_bbox_weights_refine,
                avg_factor=None)

        else:
            # no usable positives: emit zero losses that keep the graph
            # connected to the predictions
            losses_cls = cls_scores.sum() * 0
            losses_pts_refine = pts_preds_refine.sum() * 0
            loss_border_refine = pts_preds_refine.sum() * 0

        losses_pts_init, loss_border_init = multi_apply(
            self.init_loss_single, pts_coordinate_preds_init,
            bbox_gt_list_init, bbox_weights_list_init, self.point_strides)

        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine,
            'loss_spatial_init': loss_border_init,
            'loss_spatial_refine': loss_border_refine
        }
        return loss_dict_all
+
+ def get_bboxes(self,
+ cls_scores,
+ pts_preds_init,
+ pts_preds_refine,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform network outputs of a batch into bbox results.
+
+ Args:
+ cls_scores (list[Tensor]): Classification scores for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * num_classes, H, W).
+ pts_preds_init (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 18D-tensor, has shape
+ (batch_size, num_points * 2, H, W).
+ pts_preds_refine (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 18D-tensor, has shape
+ (batch_size, num_points * 2, H, W).
+ img_metas (list[dict], Optional): Image meta info. Default None.
+ cfg (mmcv.Config, Optional): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default None.
+ rescale (bool): If True, return boxes in original image space.
+ Default False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default True.
+
+ Returns:
+ list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 6) tensor, where the first 4 columns
+ are bounding box positions (cx, cy, w, h, a) and the
+ 6-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of
+ the corresponding box.
+ """
+ assert len(cls_scores) == len(pts_preds_refine)
+
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=cls_scores[0].dtype)
+
+ result_list = []
+
+ for img_id, _ in enumerate(img_metas):
+ img_meta = img_metas[img_id]
+ cls_score_list = select_single_mlvl(cls_scores, img_id)
+ point_pred_list = select_single_mlvl(pts_preds_refine, img_id)
+
+ results = self._get_bboxes_single(cls_score_list, point_pred_list,
+ mlvl_priors, img_meta, cfg,
+ rescale, with_nms, **kwargs)
+ result_list.append(results)
+
+ return result_list
+
    def _get_bboxes_single(self,
                           cls_score_list,
                           point_pred_list,
                           mlvl_priors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            point_pred_list (list[Tensor]): Point offsets from all scale
                levels of a single image, each item has shape
                (num_points * 2, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 2).
            img_meta (dict): Image meta info.
            cfg: Test / postprocessing configuration; if None, test_cfg
                would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True. Only the True path is implemented.

        Returns:
            tuple: (polys, scores, det_labels) after rotated NMS, where
            polys are 8-coordinate polygons decoded from the kept rotated
            boxes.
        """

        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(point_pred_list)
        scale_factor = img_meta['scale_factor']

        mlvl_bboxes = []
        mlvl_scores = []
        for level_idx, (cls_score, points_pred, points) in enumerate(
                zip(cls_score_list, point_pred_list, mlvl_priors)):
            assert cls_score.size()[-2:] == points_pred.size()[-2:]

            # (C, H, W) -> (H*W, num_classes)
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # softmax path: the last channel is background, drop it
                scores = cls_score.softmax(-1)[:, :-1]

            # (2*num_points, H, W) -> (H*W, 2*num_points)
            points_pred = points_pred.permute(1, 2, 0).reshape(
                -1, 2 * self.num_points)
            # keep only the nms_pre highest-scoring locations per level
            nms_pre = cfg.get('nms_pre', -1)
            if 0 < nms_pre < scores.shape[0]:
                if self.use_sigmoid_cls:
                    max_scores = scores.max(dim=1)
                else:
                    max_scores = scores[:, 1:].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                points_pred = points_pred[topk_inds, :]
                scores = scores[topk_inds, :]

            # decode point offsets into polygons, then into rotated boxes
            poly_pred = self.points2rotrect(points_pred, y_first=True)
            bbox_pos_center = points[:, :2].repeat(1, 4)
            polys = poly_pred * self.point_strides[level_idx] + bbox_pos_center
            bboxes = poly2obb(polys)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        mlvl_bboxes = jt.concat(mlvl_bboxes)

        if rescale:
            # only the (cx, cy, w, h) part is rescaled; the angle is not
            mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor(
                scale_factor)
        mlvl_scores = jt.concat(mlvl_scores)
        if self.use_sigmoid_cls:
            # prepend a background column expected by multiclass_nms_rotated
            padding = jt.zeros((mlvl_scores.shape[0], 1), dtype=mlvl_scores.dtype)
            mlvl_scores = jt.concat([padding, mlvl_scores], dim=1)

        if with_nms:
            det_bboxes, det_labels = multiclass_nms_rotated(
                mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms,
                cfg.max_per_img)
            boxes = det_bboxes[:, :5]
            scores = det_bboxes[:, 5]
            polys = rotated_box_to_poly(boxes)
            return polys, scores, det_labels
        else:
            raise NotImplementedError
+
+ def parse_targets(self, targets):
+ img_metas = []
+ gt_bboxes = []
+ gt_bboxes_ignore = []
+ gt_labels = []
+
+ for target in targets:
+ if self.is_training():
+ gt_bboxes.append(target["rboxes"])
+ gt_labels.append(target["labels"])
+ gt_bboxes_ignore.append(target["rboxes_ignore"])
+ img_metas.append(dict(
+ img_shape=target["img_size"][::-1],
+ scale_factor=target["scale_factor"],
+ pad_shape = target["pad_shape"]
+ ))
+ if not self.is_training():
+ return dict(img_metas = img_metas)
+ return dict(
+ gt_bboxes = gt_bboxes,
+ gt_labels = gt_labels,
+ img_metas = img_metas,
+ gt_bboxes_ignore = gt_bboxes_ignore,
+ )
+
+ def execute(self, feats, targets):
+ outs = multi_apply(self.forward_single, feats)
+ if self.is_training():
+ return self.loss(*outs, **self.parse_targets(targets))
+ return self.get_bboxes(*outs, **self.parse_targets(targets))
+
+
def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Pick the tensors of one batch element from every scale level.

    Note: ``detach`` defaults to True because proposal gradients must be
    cut when training two-stage models (e.g. Cascade Mask R-CNN).

    Args:
        mlvl_tensors (list[Tensor]): batched tensors, one per scale level.
        batch_id (int): index of the batch element to extract.
        detach (bool): whether to detach the extracted tensors.
            Default True.

    Returns:
        list[Tensor]: the selected single-image tensor of every level.
    """
    assert isinstance(mlvl_tensors, (list, tuple))
    selected = [level[batch_id] for level in mlvl_tensors]
    if detach:
        selected = [item.detach() for item in selected]
    return selected
+
def levels_to_images(mlvl_tensor, flatten=False):
    """Regroup multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Each level tensor of shape (N, C, H, W) — or (N, H, W, C) when
    ``flatten`` is True — is reshaped to (N, H*W, C); the per-image slices
    of every level are then concatenated along the first dimension.

    Args:
        mlvl_tensor (list[jt.Var]): one tensor per level, each of shape
            (N, C, H, W), or (N, H, W, C) when ``flatten`` is True.
        flatten (bool, optional): set True when the channel dimension is
            already last.

    Returns:
        list[jt.Var]: N tensors, each of shape (num_elements, C).
    """
    num_imgs = mlvl_tensor[0].size(0)
    channels = mlvl_tensor[0].size(-1 if flatten else 1)
    per_image = [[] for _ in range(num_imgs)]
    for level_feat in mlvl_tensor:
        if not flatten:
            # move channels last before flattening the spatial dims
            level_feat = level_feat.permute(0, 2, 3, 1)
        level_feat = level_feat.view(num_imgs, -1, channels)
        for img_id in range(num_imgs):
            per_image[img_id].append(level_feat[img_id])
    return [jt.concat(feats, 0) for feats in per_image]
+
class MlvlPointGenerator:
    """Standard points generator for multi-level (Mlvl) feature maps in 2D
    points-based detectors.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels in order (w, h).
        offset (float): The offset of points, the value is normalized with
            corresponding stride. Defaults to 0.5.
    """

    def __init__(self, strides, offset=0.5):
        # _pair expands a scalar stride into a (stride_w, stride_h) tuple.
        self.strides = [_pair(stride) for stride in strides]
        self.offset = offset

    @property
    def num_levels(self):
        """int: number of feature levels that the generator will be applied"""
        return len(self.strides)

    @property
    def num_base_priors(self):
        """list[int]: The number of priors (points) at a point
        on the feature grid"""
        return [1 for _ in range(len(self.strides))]

    def _meshgrid(self, x, y, row_major=True):
        """Build the full coordinate grid from 1D x/y vectors and return it
        flattened, x-first when ``row_major`` is True."""
        yy, xx = jt.meshgrid(y, x)
        if row_major:
            # warning .flatten() would cause error in ONNX exporting
            # have to use reshape here
            return xx.reshape(-1), yy.reshape(-1)

        else:
            return yy.reshape(-1), xx.reshape(-1)

    def grid_priors(self,
                    featmap_sizes,
                    dtype=jt.float32,
                    with_stride=False):
        """Generate grid points of multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels, each size arrange as
                as (h, w).
            dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32.
            with_stride (bool): Whether to concatenate the stride to
                the last dimension of points.

        Return:
            list[jt.Var]: Points of multiple feature levels.
            The sizes of each tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """

        assert self.num_levels == len(featmap_sizes)
        multi_level_priors = []
        for i in range(self.num_levels):
            priors = self.single_level_grid_priors(
                featmap_sizes[i],
                level_idx=i,
                dtype=dtype,
                with_stride=with_stride)
            multi_level_priors.append(priors)
        return multi_level_priors

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=jt.float32,
                                 with_stride=False):
        """Generate grid Points of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature maps, arrange as
                (h, w).
            level_idx (int): The index of corresponding feature map level.
            dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32.
            with_stride (bool): Concatenate the stride to the last dimension
                of points.

        Return:
            Tensor: Points of single feature levels.
            The shape of tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        # offset shifts the points off the cell corner (0.5 = cell center)
        shift_x = (jt.arange(0, feat_w) +
                   self.offset) * stride_w
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_x = shift_x.to(dtype)

        shift_y = (jt.arange(0, feat_h) +
                   self.offset) * stride_h
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_y = shift_y.to(dtype)
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        if not with_stride:
            shifts = jt.stack([shift_xx, shift_yy], dim=-1)
        else:
            # use `shape[0]` instead of `len(shift_xx)` for ONNX export
            stride_w = jt.full((shift_xx.shape[0], ), stride_w, dtype=dtype)
            stride_h = jt.full((shift_yy.shape[0], ), stride_h, dtype=dtype)
            shifts = jt.stack([shift_xx, shift_yy, stride_w, stride_h],
                              dim=-1)
        all_points = shifts
        return all_points

    def valid_flags(self, featmap_sizes, pad_shape):
        """Generate valid flags of points of multiple feature levels.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels, each size arrange as
                as (h, w).
            pad_shape (tuple(int)): The padded shape of the image,
                arrange as (h, w).

        Return:
            list(jt.Var): Valid flags of points of multiple levels.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):
            point_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]
            # number of grid cells actually covered by the padded image
            valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
            valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w))
            multi_level_flags.append(flags)
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size):
        """Generate the valid flags of points of a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps, arrange as
                as (h, w).
            valid_size (tuple[int]): The valid size of the feature maps.
                The size arrange as as (h, w).

        Returns:
            jt.Var: The valid flags of each points in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = jt.zeros(feat_w, dtype=jt.bool)
        valid_y = jt.zeros(feat_h, dtype=jt.bool)
        # mark the leading valid rows/columns, then combine via meshgrid
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        return valid

    def sparse_priors(self,
                      prior_idxs,
                      featmap_size,
                      level_idx,
                      dtype=jt.float32):
        """Generate sparse points according to the ``prior_idxs``.

        Args:
            prior_idxs (Tensor): The index of corresponding anchors
                in the feature map.
            featmap_size (tuple[int]): feature map size arrange as (w, h).
            level_idx (int): The level index of corresponding feature
                map.
            dtype (obj:`jt.dtype`): Date type of points. Defaults to
                ``jt.float32``.

        Returns:
            Tensor: Anchor with shape (N, 2), N should be equal to
            the length of ``prior_idxs``. And last dimension
            2 represent (coord_x, coord_y).
        """
        height, width = featmap_size
        # recover (col, row) from the flat index, then scale by the stride
        x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
        y = ((prior_idxs // width) % height +
             self.offset) * self.strides[level_idx][1]
        prioris = jt.stack([x, y], 1).to(dtype)
        return prioris
diff --git a/python/jdet/models/roi_heads/rotated_reppoints_head.py b/python/jdet/models/roi_heads/rotated_reppoints_head.py
index 962887f..d5a3c63 100644
--- a/python/jdet/models/roi_heads/rotated_reppoints_head.py
+++ b/python/jdet/models/roi_heads/rotated_reppoints_head.py
@@ -15,6 +15,7 @@
from jittor.nn import _pair
+
import numpy as np
def deleteme(a, b, size = 10):
if a is None and b is None:
@@ -67,7 +68,17 @@ def transpose_to(a, b):
print(type(b))
raise NotImplementedError
-
+def fake_argsort2(x, dim=0, descending=False):
+    # NOTE(review): numpy-backed replacement for jittor's sort that
+    # guarantees a *stable* ordering (kind="stable") so tie-breaking is
+    # deterministic. Returns (sorted_values, sort_indices) as jt arrays.
+    # Appears to be used only with 1-D inputs here; for dim != 0 on a
+    # multi-dim array the plain ``x_[index_]`` gather would be wrong —
+    # TODO confirm before reusing elsewhere.
+    x_ = x.data
+    if (descending):
+        x__ = -x_
+    else:
+        x__ = x_
+    index_ = np.argsort(x__, axis=dim, kind="stable")
+    y_ = x_[index_]
+    index = jt.array(index_)
+    y = jt.array(y_)
+    return y, index
@HEADS.register_module()
class RotatedRepPointsHead(nn.Module):
@@ -328,11 +339,13 @@ def _point_target_single(self,
stage='init',
unmap_outputs=True):
"""Single point target function."""
+ print('A: ', flat_proposals.shape, valid_flags.shape, gt_bboxes.shape, gt_labels.shape)
inside_flags = valid_flags
if not inside_flags.any():
return (None, ) * 8
# assign gt and sample proposals
proposals = flat_proposals[inside_flags, :]
+ print(' proposals: ', len(proposals))
if stage == 'init':
assigner = self.init_assigner
@@ -352,16 +365,23 @@ def _point_target_single(self,
# convert gt from obb to poly
gt_bboxes = obb2poly(gt_bboxes)
+
+ # if stage != 'init':
assign_result = assigner.assign(proposals, gt_bboxes,
gt_bboxes_ignore,
None if self.sampling else gt_labels)
+ # assign_result = assigner.assign(proposals, gt_bboxes,
+ # gt_bboxes_ignore, gt_labels)
sampling_result = self.sampler.sample(assign_result, proposals,
gt_bboxes)
+
+
# if stage != 'init':
- # out_list = [sampling_result.pos_inds, sampling_result.neg_inds,
- # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds,
- # assign_result.gt_inds]
+ # out_list = [sampling_result.pos_inds, sampling_result.neg_inds,
+ # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds,
+ # assign_result.gt_inds]
+ # print('sampling_result: ', len(sampling_result.pos_inds), len(sampling_result.neg_inds), len(sampling_result.pos_gt_bboxes), len(sampling_result.pos_assigned_gt_inds), len(assign_result.gt_inds))
# result_list = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/result_dict.pkl", "rb"))
# deleteme(out_list, result_list)
# exit(0)
@@ -373,10 +393,12 @@ def _point_target_single(self,
labels = jt.full((num_valid_proposals, ),
self.background_label,
dtype=jt.int32)
+ # print('init labels: ', labels)
label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
+ print('pos_inds', len(pos_inds))
if len(pos_inds) > 0:
pos_gt_bboxes = sampling_result.pos_gt_bboxes
bbox_gt[pos_inds, :] = pos_gt_bboxes
@@ -409,6 +431,8 @@ def _point_target_single(self,
inside_flags)
proposals_weights = unmap(proposals_weights, num_total_proposals,
inside_flags)
+
+ print('In _point_target_single: ', len(labels), len((labels > 0).nonzero().view(-1)))
return (labels, label_weights, bbox_gt, pos_proposals,
proposals_weights, pos_inds, neg_inds, sampling_result)
@@ -460,6 +484,7 @@ def get_targets(self,
"""
assert stage in ['init', 'refine']
num_imgs = len(img_metas)
+ print('num_imgs: ', num_imgs)
assert len(proposals_list) == len(valid_flag_list) == num_imgs
# points number of multi levels
@@ -581,9 +606,14 @@ def get_cfa_targets(self,
pos_inds = []
# pos_gt_index = []
for i, single_labels in enumerate(all_labels):
- pos_mask = (0 < single_labels) & (
- single_labels <= self.num_classes) #TODO(514flowey): num_class not include background
+ # pos_mask = (0 < single_labels) & (
+ # single_labels <= self.num_classes) #TODO(514flowey): num_class not include background
+ # print('single_labels: ', single_labels)
+ pos_mask = single_labels > 0
+ # pos_mask = (0 < single_labels) & (
+ # single_labels <= self.num_classes)
pos_inds.append(pos_mask.nonzero().view(-1))
+ print('In get_cfa_targets pos_mask.nonzero().view(-1): ', len(pos_mask.nonzero().view(-1)), len(single_labels))
gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
@@ -658,8 +688,6 @@ def loss(self,
assert len(featmap_sizes) == self.prior_generator.num_levels
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
-
-
# import pickle
# input_dict = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/input_dict.pkl", "rb"))
# featmap_sizes = input_dict['featmap_sizes']
@@ -748,12 +776,17 @@ def loss(self,
item.reshape(-1, 2 * self.num_points)
for item in pts_coordinate_preds_refine
]
+
+ # labels_init = jt.concat(labels_list, 0).view(-1)
+ # pos_inds_flatten_init = (labels_init > 0).nonzero().reshape(-1)
+
with jt.no_grad():
pos_losses_list, = multi_apply(
self.get_pos_loss, cls_scores,
pts_coordinate_preds_init_cfa, labels_list,
rbbox_gt_list_refine, label_weights_list,
convex_weights_list_refine, pos_inds_list_refine)
+ # print('pos_inds_list_refine: ', len(pos_inds_list_refine[0]))
labels_list, label_weights_list, convex_weights_list_refine, \
num_pos, pos_normalize_term = multi_apply(
self.reassign,
@@ -781,11 +814,22 @@ def loss(self,
convex_weights_refine = jt.concat(convex_weights_list_refine,
0).view(-1)
pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1)
- pos_inds_flatten = ((0 <= labels) &
- (labels < self.num_classes)).nonzero(
- as_tuple=False).reshape(-1)
+ # pos_inds_flatten = ((0 <= labels) &
+ # (labels < self.num_classes)).nonzero(
+ # ).reshape(-1)
+ pos_inds_flatten = (labels > 0).nonzero().reshape(-1)
+
+ # labels = jt.concat(labels_list, 0).view(-1)
+ # pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1)
+ # if len(pos_normalize_term) != len(pos_inds_flatten):
+ # print('len(pos_normalize_term): ', len(pos_normalize_term))
+ # print('len(pos_inds_flatten): ', len(pos_inds_flatten))
+ # print('len(pos_inds_flatten_init): ', len(pos_inds_flatten_init))
+ # exit()
assert len(pos_normalize_term) == len(pos_inds_flatten)
- if num_pos:
+
+
+ if bool(num_pos):
losses_cls = self.loss_cls(
cls_scores, labels, labels_weight, avg_factor=num_pos)
pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten]
@@ -951,10 +995,14 @@ def reassign(self,
- pos_normalize_term (list): pos normalize term for refine \
points losses.
"""
+ # print('label: ', label.shape)
+ print('pos_inds: ', len(pos_inds), len(pos_gt_inds))
+
if len(pos_inds) == 0:
return label, label_weight, convex_weight, 0, jt.array([]).type_as(convex_weight)
num_gt = pos_gt_inds.max() + 1
+
num_proposals_each_level_ = num_proposals_each_level.copy()
num_proposals_each_level_.insert(0, 0)
inds_level_interval = np.cumsum(num_proposals_each_level_)
@@ -971,7 +1019,7 @@ def reassign(self,
pos_inds_after_cfa = []
ignore_inds_after_cfa = []
re_assign_weights_after_cfa = []
- for gt_ind in range(num_gt):
+ for gt_ind in range(int(num_gt)):
pos_inds_cfa = []
pos_loss_cfa = []
pos_overlaps_init_cfa = []
@@ -980,7 +1028,7 @@ def reassign(self,
level_mask = pos_level_mask[level]
level_gt_mask = level_mask & gt_mask
value, topk_inds = pos_losses[level_gt_mask].topk(
- min(level_gt_mask.sum(), self.topk), largest=False)
+ int(min(level_gt_mask.sum(), self.topk)), largest=False)
pos_inds_cfa.append(pos_inds[level_gt_mask][topk_inds])
pos_loss_cfa.append(value)
pos_overlaps_init_cfa.append(
@@ -990,10 +1038,11 @@ def reassign(self,
pos_overlaps_init_cfa = jt.concat(pos_overlaps_init_cfa, 1)
if len(pos_inds_cfa) < 2:
pos_inds_after_cfa.append(pos_inds_cfa)
- ignore_inds_after_cfa.append(jt.empty((0)))
+ ignore_inds_after_cfa.append(jt.empty([]))
re_assign_weights_after_cfa.append(jt.ones([len(pos_inds_cfa)]))
else:
- pos_loss_cfa, sort_inds = pos_loss_cfa.sort()
+ # pos_loss_cfa, sort_inds = pos_loss_cfa.sort()
+ pos_loss_cfa, sort_inds = fake_argsort2(pos_loss_cfa)
pos_inds_cfa = pos_inds_cfa[sort_inds]
pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, sort_inds] \
.reshape(-1, len(pos_inds_cfa))
@@ -1004,8 +1053,9 @@ def reassign(self,
gauss_prob_density = \
(-(pos_loss_cfa - loss_mean) ** 2 / loss_var) \
.exp() / loss_var.sqrt()
- index_inverted, _ = jt.arange(
- len(gauss_prob_density)).sort(descending=True)
+ # index_inverted, _ = jt.arange(
+ # len(gauss_prob_density)).sort(descending=True)
+ index_inverted, _ = fake_argsort2(jt.arange(len(gauss_prob_density)), descending=True)
gauss_prob_inverted = jt.cumsum(
gauss_prob_density[index_inverted], 0)
gauss_prob = gauss_prob_inverted[index_inverted]
@@ -1015,6 +1065,7 @@ def reassign(self,
# splitting by gradient consistency
loss_curve = gauss_prob_norm * pos_loss_cfa
_, max_thr = loss_curve.topk(1)
+ max_thr = int(max_thr)
reweights = gauss_prob_norm[:max_thr + 1]
# feature anti-aliasing coefficient
pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, :max_thr + 1]
@@ -1028,19 +1079,24 @@ def reassign(self,
jt.ones(len(reweights)).type_as(
gauss_prob_norm).sum()
pos_inds_temp = pos_inds_cfa[:max_thr + 1]
- ignore_inds_temp = pos_inds_cfa.new_tensor([])
-
+ # print('pos_inds_cfa: ', pos_inds_cfa)
+ # print('pos_inds_temp: ', pos_inds_temp)
+ ignore_inds_temp = jt.empty([])
+
pos_inds_after_cfa.append(pos_inds_temp)
ignore_inds_after_cfa.append(ignore_inds_temp)
re_assign_weights_after_cfa.append(re_assign_weights)
-
+
pos_inds_after_cfa = jt.concat(pos_inds_after_cfa)
ignore_inds_after_cfa = jt.concat(ignore_inds_after_cfa)
re_assign_weights_after_cfa = jt.concat(re_assign_weights_after_cfa)
-
+ print('pos_inds_after_cfa: ', len(pos_inds), len(pos_inds_after_cfa))
reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_cfa).all(1)
+ print('reassign_mask: ', len(reassign_mask))
reassign_ids = pos_inds[reassign_mask]
- label[reassign_ids] = self.num_classes
+ # label[reassign_ids] = self.num_classes
+ # print('reassign_ids: ', reassign_ids)
+ label[reassign_ids] = 0
label_weight[ignore_inds_after_cfa] = 0
convex_weight[reassign_ids] = 0
num_pos = len(pos_inds_after_cfa)
@@ -1060,11 +1116,21 @@ def reassign(self,
0).type_as(label)
pos_normalize_term = pos_level_mask_after_cfa * (
self.point_base_scale *
- jt.as_tensor(self.point_strides).type_as(label)).reshape(-1, 1)
+ jt.array(self.point_strides).type_as(label)).reshape(-1, 1)
pos_normalize_term = pos_normalize_term[
pos_normalize_term > 0].type_as(convex_weight)
assert len(pos_normalize_term) == len(pos_inds_after_cfa)
+ # label = jt.concat(label, 0).view(-1)
+ pos_inds_flatten = (label > 0).nonzero().reshape(-1)
+ # print('len(pos_normalize_term): ', len(pos_normalize_term))
+ # print('len(pos_inds_flatten): ', len(pos_inds_flatten))
+ # pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1)
+ # if len(pos_normalize_term) != len(pos_inds_flatten):
+ # print('len(pos_normalize_term): ', len(pos_normalize_term))
+ # print('len(pos_inds_flatten): ', len(pos_inds_flatten))
+ # exit()
+
return label, label_weight, convex_weight, num_pos, pos_normalize_term
def get_bboxes(self,
diff --git a/python/jdet/models/roi_heads/sam_reppoints_head.py b/python/jdet/models/roi_heads/sam_reppoints_head.py
new file mode 100644
index 0000000..5398e01
--- /dev/null
+++ b/python/jdet/models/roi_heads/sam_reppoints_head.py
@@ -0,0 +1,1308 @@
+from jdet.utils.registry import HEADS, LOSSES, BOXES, build_from_cfg
+from jdet.models.utils.modules import ConvModule
+from jdet.ops.dcn_v1 import DeformConv
+from jdet.ops.bbox_transforms import obb2poly, poly2obb
+from jdet.utils.general import multi_apply, unmap
+from jdet.ops.nms_rotated import multiclass_nms_rotated
+from jdet.models.boxes.anchor_target import images_to_levels
+from jdet.ops.reppoints_min_area_bbox import reppoints_min_area_bbox
+from jdet.models.boxes.box_ops import rotated_box_to_poly
+
+
+import jittor as jt
+from jittor import nn
+import numpy as np
+from jittor.nn import _pair
+
+
+def get_num_level_anchors_inside(num_level_anchors, inside_flags):
+ """Get number of every level anchors inside.
+
+ Args:
+ num_level_anchors (List[int]): List of number of every level's anchors.
+ inside_flags (jt.Tensor): Flags of all anchors.
+
+ Returns:
+ List[int]: List of number of inside anchors.
+ """
+ split_inside_flags = jt.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
+
+def points_center_pts(RPoints, y_first=True):
+    """Compute center point of Pointsets.
+
+    Args:
+        RPoints (jt.Tensor): the lists of Pointsets, shape (k, 18).
+        y_first (bool, optional): if True, the sequence of Pointsets is (y,x).
+
+    Returns:
+        center_pts (jt.Tensor): the mean_center coordination of Pointsets,
+            shape (k, 2).
+    """
+    RPoints = RPoints.reshape(-1, 9, 2)
+
+    # after the reshape the last axis holds one coordinate pair, so
+    # ``0::2`` / ``1::2`` select its first / second component
+    if y_first:
+        pts_dy = RPoints[:, :, 0::2]
+        pts_dx = RPoints[:, :, 1::2]
+    else:
+        pts_dx = RPoints[:, :, 0::2]
+        pts_dy = RPoints[:, :, 1::2]
+    pts_dy_mean = pts_dy.mean(dim=1).reshape(-1, 1)
+    pts_dx_mean = pts_dx.mean(dim=1).reshape(-1, 1)
+    center_pts = jt.concat([pts_dx_mean, pts_dy_mean], dim=1).reshape(-1, 2)
+    return center_pts
+
+def to_oc(boxes):
+    """Normalize rotated boxes ``(x, y, w, h, t)`` to an 'oc'-style angle range.
+
+    The angle is folded into a half-pi range starting at ``-pi/2``; whenever
+    the folded angle crosses ``pi/2`` the width and height are swapped so the
+    described box geometry is unchanged.
+
+    Args:
+        boxes (jt.Tensor): boxes whose last dim is (x, y, w, h, t).
+
+    Returns:
+        jt.Tensor: boxes of the same shape in the normalized convention.
+    """
+    x, y, w, h, t = boxes.unbind(dim=-1)
+    start_angle = -0.5 * np.pi
+    t = ((t - start_angle) % np.pi)
+    w_ = jt.where(t < np.pi / 2, w, h)
+    h_ = jt.where(t < np.pi / 2, h, w)
+    t = jt.where(t < np.pi / 2, t, t - np.pi / 2) + start_angle
+    return jt.stack([x, y, w_, h_, t], dim=-1)
+
+def deleteme(a, b, size = 10):
+    # NOTE(review): debug-only helper (as the name says): recursively walks
+    # two parallel nested structures and prints the max absolute difference
+    # between matching tensors/arrays/numbers. Should be removed before merge.
+    if a is None and b is None:
+        return
+    if isinstance(a, dict) and isinstance(b, dict):
+        print('-' * size)
+        for a1, b1 in zip(a.values(), b.values()):
+            deleteme(a1, b1, size + 10)
+        print('-' * size)
+    elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
+        print('-' * size)
+        for a1, b1 in zip(a, b):
+            deleteme(a1, b1, size + 10)
+        print('-' * size)
+    elif isinstance(a, jt.Var) and isinstance(b, np.ndarray):
+        print((a - b).abs().max().item())
+    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+        print(np.max(np.abs(a - b)))
+    elif isinstance(a, (int, float)) and isinstance(b, (int, float)):
+        print("number diff:", a - b)
+    else:
+        # mismatched structure types: surface them loudly instead of guessing
+        print(type(a))
+        print(type(b))
+        raise NotImplementedError
+def transpose_to(a, b):
+ if a is None:
+ return None
+ if isinstance(a, list) and isinstance(b, list):
+ rlist = []
+ for a1, b1 in zip(a, b):
+ rlist.append(transpose_to(a1, b1))
+ return rlist
+ elif isinstance(a, dict) and isinstance(b, dict):
+ rdict = []
+ for k in b.keys():
+ rdict[k] = transpose_to(a[k], b[k])
+ return rdict
+ elif isinstance(a, np.ndarray) and isinstance(b, jt.Var):
+ return jt.array(a)
+ elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+ return a
+ elif isinstance(a, tuple) and isinstance(b, tuple):
+ rlist = [transpose_to(a1, b1) for a1, b1 in zip(a, b)]
+ return tuple(rlist)
+ elif isinstance(a, (int, float, str)) and isinstance(b, (int, float, str)):
+ assert(type(a) == type(b))
+ return a
+ else:
+ print(type(a))
+ print(type(b))
+ raise NotImplementedError
+
+def fake_argsort2(x, dim=0, descending=False):
+    # NOTE(review): duplicate of the helper added to
+    # rotated_reppoints_head.py — numpy-backed stable sort returning
+    # (sorted_values, sort_indices) as jt arrays; consider moving to a
+    # shared utils module instead of copying. Appears 1-D-only: for
+    # dim != 0 on multi-dim input the ``x_[index_]`` gather would be
+    # wrong — TODO confirm.
+    x_ = x.data
+    if (descending):
+        x__ = -x_
+    else:
+        x__ = x_
+    index_ = np.argsort(x__, axis=dim, kind="stable")
+    y_ = x_[index_]
+    index = jt.array(index_)
+    y = jt.array(y_)
+    return y, index
+
+@HEADS.register_module()
+class SAMRepPointsHead(nn.Module):
+ """Rotated RepPoints head for SASM.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of input channels.
+ feat_channels (int): Number of feature channels.
+ point_feat_channels (int, optional): Number of channels of points
+ features.
+ stacked_convs (int, optional): Number of stacked convolutions.
+ num_points (int, optional): Number of points in points set.
+ gradient_mul (float, optional): The multiplier to gradients from
+ points refinement and recognition.
+ point_strides (Iterable, optional): points strides.
+ point_base_scale (int, optional): Bbox scale for assigning labels.
+ conv_bias (str, optional): The bias of convolution.
+ loss_cls (dict, optional): Config of classification loss.
+ loss_bbox_init (dict, optional): Config of initial points loss.
+ loss_bbox_refine (dict, optional): Config of points loss in refinement.
+ conv_cfg (dict, optional): The config of convolution.
+ norm_cfg (dict, optional): The config of normlization.
+ train_cfg (dict, optional): The config of train.
+ test_cfg (dict, optional): The config of test.
+ center_init (bool, optional): Whether to use center point assignment.
+ transform_method (str, optional): The methods to transform RepPoints
+ to bbox.
+ topk (int, optional): Number of the highest topk points. Defaults to 9.
+ anti_factor (float, optional): Feature anti-aliasing coefficient.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels,
+ point_feat_channels=256,
+ stacked_convs=3,
+ num_points=9,
+ gradient_mul=0.1,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=4,
+ background_label=0,
+ conv_bias='auto',
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
+ loss_bbox_refine=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ conv_cfg=None,
+ norm_cfg=None,
+ train_cfg=None,
+ test_cfg=None,
+ center_init=True,
+ transform_method='rotrect',
+ topk=6,
+ anti_factor=0.75,
+ **kwargs):
+ self.num_points = num_points
+ self.point_feat_channels = point_feat_channels
+ self.center_init = center_init
+
+ # we use deform conv to extract points features
+ self.dcn_kernel = int(np.sqrt(num_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+ assert self.dcn_kernel * self.dcn_kernel == num_points, \
+ 'The points number should be a square number.'
+ assert self.dcn_kernel % 2 == 1, \
+ 'The points number should be an odd square number.'
+ dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
+ self.dcn_base_offset = jt.array(dcn_base_offset).view(1, -1, 1, 1)
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+ self.conv_bias = conv_bias
+ self.loss_cls = build_from_cfg(loss_cls, LOSSES)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.gradient_mul = gradient_mul
+ self.point_base_scale = point_base_scale
+ self.point_strides = point_strides
+ self.prior_generator = MlvlPointGenerator(self.point_strides, offset=0.)
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ if self.train_cfg:
+ self.init_assigner = build_from_cfg(self.train_cfg.init.assigner, BOXES)
+ self.refine_assigner = build_from_cfg(self.train_cfg.refine.assigner, BOXES)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_from_cfg(sampler_cfg, BOXES)
+ self.transform_method = transform_method
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+ self.loss_bbox_init = build_from_cfg(loss_bbox_init, LOSSES)
+ self.loss_bbox_refine = build_from_cfg(loss_bbox_refine, LOSSES)
+ self.topk = topk
+ self.anti_factor = anti_factor
+ self.background_label = background_label
+ self._init_layers()
+
+    def _init_layers(self):
+        """Initialize layers of the head.
+
+        Builds two parallel conv towers (classification / regression), the
+        initial points predictor, and deform-conv based refine/classify
+        branches whose offsets come from the initial points.
+        """
+        self.relu = nn.ReLU()
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            self.cls_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    bias=self.conv_bias))
+            self.reg_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    bias=self.conv_bias))
+        # each point contributes an (x, y) offset pair
+        pts_out_dim = 2 * self.num_points
+        self.reppoints_cls_conv = DeformConv(self.feat_channels,
+                                             self.point_feat_channels,
+                                             self.dcn_kernel, 1,
+                                             self.dcn_pad)
+        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
+                                           self.cls_out_channels, 1, 1, 0)
+        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
+                                                 self.point_feat_channels, 3,
+                                                 1, 1)
+        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
+                                                pts_out_dim, 1, 1, 0)
+        self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
+                                                    self.point_feat_channels,
+                                                    self.dcn_kernel, 1,
+                                                    self.dcn_pad)
+        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
+                                                  pts_out_dim, 1, 1, 0)
+
+    def points2rotrect(self, pts, y_first=True):
+        """Convert points to oriented bboxes.
+
+        Args:
+            pts (jt.Tensor): point sets, shape (N, 2 * num_points).
+            y_first (bool): if True the input is (y, x)-interleaved and is
+                re-ordered to (x, y) before conversion.
+
+        Returns:
+            jt.Tensor: minimum-area rotated rectangle per point set.
+
+        Raises:
+            NotImplementedError: for any ``transform_method`` other than
+                ``'rotrect'``.
+        """
+        if y_first:
+            pts = pts.reshape(-1, self.num_points, 2)
+            pts_dy = pts[:, :, 0::2]
+            pts_dx = pts[:, :, 1::2]
+            pts = jt.concat([pts_dx, pts_dy],
+                            dim=2).reshape(-1, 2 * self.num_points)
+        if self.transform_method == 'rotrect':
+            rotrect_pred = reppoints_min_area_bbox(pts)
+            return rotrect_pred
+        else:
+            raise NotImplementedError
+
+    def forward_single(self, x):
+        """Forward feature map of a single FPN level.
+
+        Returns:
+            tuple: (cls_out, pts_out_init, pts_out_refine) for this level.
+        """
+        dcn_base_offset = self.dcn_base_offset.type_as(x)
+        points_init = 0
+        cls_feat = x
+        pts_feat = x
+        for cls_conv in self.cls_convs:
+            cls_feat = cls_conv(cls_feat)
+        for reg_conv in self.reg_convs:
+            pts_feat = reg_conv(pts_feat)
+        # initialize reppoints
+        pts_out_init = self.reppoints_pts_init_out(
+            self.relu(self.reppoints_pts_init_conv(pts_feat)))
+        pts_out_init = pts_out_init + points_init
+        # refine and classify reppoints
+        # partial detach: only ``gradient_mul`` of the refine/cls gradient
+        # flows back into the initial points prediction
+        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
+        # deform-conv offsets are relative to the kernel's base grid
+        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
+        cls_out = self.reppoints_cls_out(
+            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
+        pts_out_refine = self.reppoints_pts_refine_out(
+            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
+        # refine predicts a residual on top of the (detached) init points
+        pts_out_refine = pts_out_refine + pts_out_init.detach()
+
+        return cls_out, pts_out_init, pts_out_refine
+
+    def get_points(self, featmap_sizes, img_metas):
+        """Get points according to feature map sizes.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            img_metas (list[dict]): Image meta info.
+
+        Returns:
+            tuple: points of each image, valid flags of each image
+        """
+        num_imgs = len(img_metas)
+
+        # feature map sizes are shared across the batch, so the grid priors
+        # are computed once and cloned per image
+        multi_level_points = self.prior_generator.grid_priors(
+            featmap_sizes, with_stride=True)
+        points_list = [[point.clone() for point in multi_level_points]
+                       for _ in range(num_imgs)]
+
+        # per-image valid flags, derived from each image's padded shape
+        valid_flag_list = []
+        for img_id, img_meta in enumerate(img_metas):
+            multi_level_flags = self.prior_generator.valid_flags(
+                featmap_sizes, img_meta['pad_shape'])
+            valid_flag_list.append(multi_level_flags)
+
+        return points_list, valid_flag_list
+
+    def offset_to_pts(self, center_list, pred_list):
+        """Change from point offset to point coordinate.
+
+        Predicted offsets are stored (y, x)-interleaved per point; the
+        output is (x, y)-ordered absolute coordinates, scaled by each
+        level's stride and added to the grid centers.
+        """
+        pts_list = []
+        for i_lvl, _ in enumerate(self.point_strides):
+            pts_lvl = []
+            for i_img, _ in enumerate(center_list):
+                pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points)
+                pts_shift = pred_list[i_lvl][i_img]
+                # (C, H, W) -> (H*W, 2*num_points)
+                yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points)
+                y_pts_shift = yx_pts_shift[..., 0::2]
+                x_pts_shift = yx_pts_shift[..., 1::2]
+                xy_pts_shift = jt.stack([x_pts_shift, y_pts_shift], -1)
+                xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
+                pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
+                pts_lvl.append(pts)
+            pts_lvl = jt.stack(pts_lvl, 0)
+            pts_list.append(pts_lvl)
+        return pts_list
+
+ def _point_target_single(self,
+ flat_proposals,
+ num_level_proposals,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ overlaps,
+ stage='init',
+ unmap_outputs=True):
+ """Single point target function."""
+ # print('A: ', flat_proposals.shape, valid_flags.shape, gt_bboxes.shape, gt_labels.shape)
+ inside_flags = valid_flags
+ if not inside_flags.any():
+ return (None, ) * 9
+ # assign gt and sample proposals
+ proposals = flat_proposals[inside_flags, :]
+
+ num_level_anchors_inside = get_num_level_anchors_inside(
+ num_level_proposals, inside_flags)
+
+ # convert gt from obb to poly
+ gt_bboxes = obb2poly(gt_bboxes)
+
+ if stage == 'init':
+ assigner = self.init_assigner
+ pos_weight = self.train_cfg.init.pos_weight
+ assign_result = assigner.assign(proposals, gt_bboxes,
+ gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ else:
+ assigner = self.refine_assigner
+ pos_weight = self.train_cfg.refine.pos_weight
+ if self.train_cfg.refine.assigner.type not in (
+ 'ATSSAssigner', 'ATSSConvexAssigner', 'SASAssigner'):
+ assign_result = assigner.assign(
+ proposals, gt_bboxes, overlaps, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ else:
+ assign_result = assigner.assign(
+ proposals, num_level_anchors_inside, gt_bboxes,
+ gt_bboxes_ignore, None if self.sampling else gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, proposals,
+ gt_bboxes)
+ # if stage != 'init':
+ # import pickle
+ # input_dict = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/input_dict.pkl", "rb"))
+ # proposals = transpose_to(input_dict['proposals'], proposals)
+ # gt_bboxes = transpose_to(input_dict['gt_bboxes'], gt_bboxes)
+ # gt_labels = transpose_to(input_dict['gt_labels'], gt_labels)
+ # gt_bboxes_ignore = transpose_to(input_dict['gt_bboxes_ignore'], gt_bboxes_ignore)
+
+
+
+
+ # if stage != 'init':
+ # assign_result = assigner.assign(proposals, gt_bboxes,
+ # gt_bboxes_ignore,
+ # None if self.sampling else gt_labels)
+ # # assign_result = assigner.assign(proposals, gt_bboxes,
+ # # gt_bboxes_ignore, gt_labels)
+ # sampling_result = self.sampler.sample(assign_result, proposals,
+ # gt_bboxes)
+
+
+
+ # if stage != 'init' and len(sampling_result.pos_inds) == 21824:
+ # import pickle
+ # input_list = [flat_proposals, valid_flags, proposals, gt_bboxes, gt_labels]
+ # print('input_list: ', proposals[0],proposals[10000], proposals[20000])
+ # print('input_list: ', gt_bboxes[0])
+ # result_list = pickle.dump(input_list, open("/home/zytx121/jittor/JDet-refactor_r3det/input_dict.pkl", 'wb'))
+
+ # out_list = [sampling_result.pos_inds, sampling_result.neg_inds,
+ # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds,
+ # assign_result.gt_inds]
+ # result_list = pickle.dump(out_list, open("/home/zytx121/jittor/JDet-refactor_r3det/result_dict.pkl", 'wb'))
+ # # print('out_list: ', out_list)
+
+ # exit(0)
+
+ gt_inds = assign_result.gt_inds
+ num_valid_proposals = proposals.shape[0]
+ bbox_gt = jt.zeros([num_valid_proposals, 8], dtype=proposals.dtype)
+ pos_proposals = jt.zeros_like(proposals)
+ proposals_weights = jt.zeros(num_valid_proposals, dtype=proposals.dtype)
+ labels = jt.full((num_valid_proposals, ),
+ self.background_label,
+ dtype=jt.int32)
+ # print('init labels: ', labels)
+ label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ # print('pos_inds', len(pos_inds))
+ if len(pos_inds) > 0:
+ pos_gt_bboxes = sampling_result.pos_gt_bboxes
+ bbox_gt[pos_inds, :] = pos_gt_bboxes
+ pos_proposals[pos_inds, :] = proposals[pos_inds, :]
+ proposals_weights[pos_inds] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ # TODO(514flowey): first class is 1
+ # labels[pos_inds] = 0
+ labels[pos_inds] = 1
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # use la
+ rbboxes_center, width, height, angles = jt.split(
+ to_oc(poly2obb(bbox_gt)), [2, 1, 1, 1], dim=-1)
+
+ if stage == 'init':
+ points_xy = pos_proposals[:, :2]
+ else:
+ points_xy = points_center_pts(pos_proposals, y_first=True)
+
+ distances = jt.zeros_like(angles).reshape(-1)
+ # print('angles: ', angles.max(), angles.min())
+ # angles_index_wh = ((width != 0) & (angles >= 0) &
+ # (angles <= 1.57)).squeeze()
+ # angles_index_hw = ((width != 0) & ((angles < 0) |
+ # (angles > 1.57))).squeeze()
+
+ # print('angles_index_hw: ', bool(angles_index_hw.sum()), bool(angles_index_wh.sum()))
+
+ # 01_la:compution of distance
+ distances = jt.sqrt(
+ (jt.pow(
+ rbboxes_center[:, 0] -
+ points_xy[:, 0], 2) /
+ width.squeeze()) +
+ (jt.pow(
+ rbboxes_center[:, 1] -
+ points_xy[:, 1], 2) /
+ height.squeeze()))
+ # if bool(angles_index_wh.sum()):
+ # distances[angles_index_wh] = jt.sqrt(
+ # (jt.pow(
+ # rbboxes_center[angles_index_wh, 0] -
+ # points_xy[angles_index_wh, 0], 2) /
+ # width[angles_index_wh].squeeze()) +
+ # (jt.pow(
+ # rbboxes_center[angles_index_wh, 1] -
+ # points_xy[angles_index_wh, 1], 2) /
+ # height[angles_index_wh].squeeze()))
+ # if bool(angles_index_hw.sum()):
+ # distances[angles_index_hw] = jt.sqrt(
+ # (jt.pow(
+ # rbboxes_center[angles_index_hw, 0] -
+ # points_xy[angles_index_hw, 0], 2) /
+ # height[angles_index_hw].squeeze()) +
+ # (jt.pow(
+ # rbboxes_center[angles_index_hw, 1] -
+ # points_xy[angles_index_hw, 1], 2) /
+ # width[angles_index_hw].squeeze()))
+ distances[distances == float('nan')] = 0.
+
+ sam_weights = label_weights * (jt.exp(1 / (distances + 1)))
+ sam_weights[sam_weights == float('inf')] = 0.
+
+ # map up to original set of proposals
+ if unmap_outputs:
+ num_total_proposals = flat_proposals.size(0)
+ labels = unmap(labels, num_total_proposals, inside_flags)
+ label_weights = unmap(label_weights, num_total_proposals,
+ inside_flags)
+ bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
+ pos_proposals = unmap(pos_proposals, num_total_proposals,
+ inside_flags)
+ proposals_weights = unmap(proposals_weights, num_total_proposals,
+ inside_flags)
+ gt_inds = unmap(gt_inds, num_total_proposals, inside_flags)
+ sam_weights = unmap(sam_weights, num_total_proposals, inside_flags)
+
+ # print('In _point_target_single: ', len(labels), len((labels > 0).nonzero().view(-1)))
+
+ return (labels, label_weights, bbox_gt, pos_proposals,
+ proposals_weights, pos_inds, neg_inds, gt_inds, sam_weights)
+
+    def get_targets(self,
+                    proposals_list,
+                    valid_flag_list,
+                    gt_bboxes_list,
+                    img_metas,
+                    gt_bboxes_ignore_list=None,
+                    gt_labels_list=None,
+                    stage='init',
+                    label_channels=1,
+                    unmap_outputs=True):
+        """Compute corresponding GT box and classification targets for
+        proposals.
+
+        Args:
+            proposals_list (list[list]): Multi level points/bboxes of each
+                image.
+            valid_flag_list (list[list]): Multi level valid flags of each
+                image.
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+            img_metas (list[dict]): Meta info of each image.
+            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+                ignored.
+            gt_labels_list (list[Tensor]): Ground truth labels of each box.
+            stage (str): `init` or `refine`. Generate target for init stage or
+                refine stage
+            label_channels (int): Channel of label.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple (list[Tensor]):
+
+                - labels_list (list[Tensor]): Labels of each level.
+                - label_weights_list (list[Tensor]): Label weights of each \
+                    level.
+                - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
+                - proposal_list (list[Tensor]): Proposals(points/bboxes) of \
+                    each level.
+                - proposal_weights_list (list[Tensor]): Proposal weights of \
+                    each level.
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+                - gt_inds_list (list[Tensor]): Assigned gt indices of each \
+                    level.
+                - sam_init_weights_list (list[Tensor]): SASM shape-adaptive \
+                    weights of each level.
+        """
+        assert stage in ['init', 'refine']
+        num_imgs = len(img_metas)
+        # print('num_imgs: ', num_imgs)
+        assert len(proposals_list) == len(valid_flag_list) == num_imgs
+
+        # points number of multi levels
+        num_level_proposals = [points.size(0) for points in proposals_list[0]]
+        num_level_proposals_list = [num_level_proposals] * num_imgs
+
+        # concat all level points and flags to a single tensor
+        for i in range(num_imgs):
+            assert len(proposals_list[i]) == len(valid_flag_list[i])
+            proposals_list[i] = jt.concat(proposals_list[i])
+            valid_flag_list[i] = jt.concat(valid_flag_list[i])
+
+        # compute targets for each image
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+        if gt_labels_list is None:
+            gt_labels_list = [None for _ in range(num_imgs)]
+        len_gt_labels = len(gt_bboxes_list)
+        all_overlaps_rotate_list = [None] * len_gt_labels
+        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
+         all_proposal_weights, pos_inds_list, neg_inds_list, all_gt_inds_list,
+         all_sam_init_weights) = multi_apply(
+             self._point_target_single,
+             proposals_list,
+             num_level_proposals_list,
+             valid_flag_list,
+             gt_bboxes_list,
+             gt_bboxes_ignore_list,
+             gt_labels_list,
+             all_overlaps_rotate_list,
+             stage=stage,
+             unmap_outputs=unmap_outputs)
+        # no valid points
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled points of all images
+        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+        labels_list = images_to_levels(all_labels, num_level_proposals)
+        label_weights_list = images_to_levels(all_label_weights,
+                                              num_level_proposals)
+        bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
+        proposals_list = images_to_levels(all_proposals, num_level_proposals)
+        proposal_weights_list = images_to_levels(all_proposal_weights,
+                                                 num_level_proposals)
+        gt_inds_list = images_to_levels(all_gt_inds_list, num_level_proposals)
+        sam_init_weights_list = images_to_levels(all_sam_init_weights,
+                                                 num_level_proposals)
+
+        return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
+                proposal_weights_list, num_total_pos, num_total_neg,
+                gt_inds_list, sam_init_weights_list)
+
def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
                label_weights, rbbox_gt_init, convex_weights_init,
                sam_weights_init, rbbox_gt_refine, convex_weights_refine,
                sam_weights_refine, stride, num_total_samples_refine):
    """Single loss function.

    Computes the classification loss plus the init- and refine-stage
    point-regression losses for one pyramid level.

    Args:
        cls_score (Tensor): Classification scores of this level,
            shape (N, num_classes, H, W).
        pts_pred_init (Tensor): Init-stage predicted points of this level.
        pts_pred_refine (Tensor): Refine-stage predicted points.
        labels (Tensor): Classification targets of each point.
        label_weights (Tensor): Classification loss weight of each point.
        rbbox_gt_init (Tensor): Init-stage GT polygons (8 coordinates).
        convex_weights_init (Tensor): Init-stage regression weights;
            entries > 0 mark positive samples.
        sam_weights_init (Tensor): SASM re-weighting for the init stage.
        rbbox_gt_refine (Tensor): Refine-stage GT polygons (8 coordinates).
        convex_weights_refine (Tensor): Refine-stage regression weights.
        sam_weights_refine (Tensor): SASM re-weighting for the refine stage.
        stride (int): Point stride of this level.
        num_total_samples_refine (int): Normalizer for the cls loss.

    Returns:
        tuple[Tensor]: (loss_cls, loss_pts_init, loss_pts_refine).
    """
    # normalize regression targets/preds by the level's base scale so
    # every pyramid level contributes on a comparable magnitude
    normalize_term = self.point_base_scale * stride

    rbbox_gt_init = rbbox_gt_init.reshape(-1, 8)
    convex_weights_init = convex_weights_init.reshape(-1)
    sam_weights_init = sam_weights_init.reshape(-1)
    # init points loss: only positions with positive convex weight count
    pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points)
    pos_ind_init = (convex_weights_init > 0).nonzero().reshape(-1)
    pts_pred_init_norm = pts_pred_init[pos_ind_init]
    rbbox_gt_init_norm = rbbox_gt_init[pos_ind_init]
    convex_weights_pos_init = convex_weights_init[pos_ind_init]
    sam_weights_pos_init = sam_weights_init[pos_ind_init]
    loss_pts_init = self.loss_bbox_init(
        pts_pred_init_norm / normalize_term,
        rbbox_gt_init_norm / normalize_term,
        convex_weights_pos_init * sam_weights_pos_init)
    # refine points loss, same scheme on the refine-stage assignment
    rbbox_gt_refine = rbbox_gt_refine.reshape(-1, 8)
    pts_pred_refine = pts_pred_refine.reshape(-1, 2 * self.num_points)
    convex_weights_refine = convex_weights_refine.reshape(-1)
    sam_weights_refine = sam_weights_refine.reshape(-1)
    pos_ind_refine = (convex_weights_refine > 0).nonzero().reshape(-1)
    pts_pred_refine_norm = pts_pred_refine[pos_ind_refine]
    rbbox_gt_refine_norm = rbbox_gt_refine[pos_ind_refine]
    convex_weights_pos_refine = convex_weights_refine[pos_ind_refine]
    sam_weights_pos_refine = sam_weights_refine[pos_ind_refine]
    loss_pts_refine = self.loss_bbox_refine(
        pts_pred_refine_norm / normalize_term,
        rbbox_gt_refine_norm / normalize_term,
        convex_weights_pos_refine * sam_weights_pos_refine)
    # classification loss, re-weighted by the refine-stage SASM weights
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    cls_score = cls_score.permute(0, 2, 3,
                                  1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        cls_score,
        labels,
        label_weights * sam_weights_refine,
        avg_factor=num_total_samples_refine)
    return loss_cls, loss_pts_init, loss_pts_refine
+
def loss(self,
         cls_scores,
         pts_preds_init,
         pts_preds_refine,
         gt_bboxes,
         gt_labels,
         img_metas,
         gt_bboxes_ignore=None):
    """Loss function of SAM RepPoints head.

    Builds init- and refine-stage targets via ``get_targets`` and then
    applies ``loss_single`` per pyramid level.

    Args:
        cls_scores (list[Tensor]): Per-level classification scores.
        pts_preds_init (list[Tensor]): Per-level init-stage point offsets.
        pts_preds_refine (list[Tensor]): Per-level refine-stage offsets.
        gt_bboxes (list[Tensor]): Per-image GT boxes.
        gt_labels (list[Tensor]): Per-image GT labels.
        img_metas (list[dict]): Per-image meta info.
        gt_bboxes_ignore (list[Tensor], optional): Per-image ignored GT.

    Returns:
        dict: 'loss_cls', 'loss_pts_init', 'loss_pts_refine', each a
        list of per-level losses.
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == self.prior_generator.num_levels
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

    # target for initial stage: assign on the raw point centers
    center_list, valid_flag_list = self.get_points(
        featmap_sizes, img_metas)
    pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                   pts_preds_init)

    if self.train_cfg.init.assigner['type'] == 'ConvexAssigner':
        candidate_list = center_list
    else:
        raise NotImplementedError
    cls_reg_targets_init = self.get_targets(
        candidate_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        stage='init',
        label_channels=label_channels)
    # only the regression-related targets of the init stage are used
    (*_, rbbox_gt_list_init, candidate_list_init, convex_weights_list_init,
     num_total_pos_init, num_total_neg_init, gt_inds_init,
     sam_weights_list_init) = cls_reg_targets_init

    # target for refinement stage: move each center by the (detached)
    # init-stage offsets so assignment runs on the refined point sets
    center_list, valid_flag_list = self.get_points(
        featmap_sizes, img_metas)
    pts_coordinate_preds_refine = self.offset_to_pts(
        center_list, pts_preds_refine)
    points_list = []
    for i_img, center in enumerate(center_list):
        points = []
        for i_lvl in range(len(pts_preds_refine)):
            points_preds_init_ = pts_preds_init[i_lvl].detach()
            points_preds_init_ = points_preds_init_.view(
                points_preds_init_.shape[0], -1,
                *points_preds_init_.shape[2:])
            # offsets are in feature-map units; scale to image pixels
            points_shift = points_preds_init_.permute(
                0, 2, 3, 1) * self.point_strides[i_lvl]
            points_center = center[i_lvl][:, :2].repeat(1, self.num_points)
            points.append(
                points_center +
                points_shift[i_img].reshape(-1, 2 * self.num_points))
        points_list.append(points)

    cls_reg_targets_refine = self.get_targets(
        points_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        stage='refine',
        label_channels=label_channels)
    (labels_list, label_weights_list, rbbox_gt_list_refine,
     candidate_list_refine, convex_weights_list_refine,
     num_total_pos_refine, num_total_neg_refine, gt_inds_refine,
     sam_weights_list_refine) = cls_reg_targets_refine
    num_total_samples_refine = (
        num_total_pos_refine +
        num_total_neg_refine if self.sampling else num_total_pos_refine)

    losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
        self.loss_single,
        cls_scores,
        pts_coordinate_preds_init,
        pts_coordinate_preds_refine,
        labels_list,
        label_weights_list,
        rbbox_gt_list_init,
        convex_weights_list_init,
        sam_weights_list_init,
        rbbox_gt_list_refine,
        convex_weights_list_refine,
        sam_weights_list_refine,
        self.point_strides,
        num_total_samples_refine=num_total_samples_refine)
    loss_dict_all = {
        'loss_cls': losses_cls,
        'loss_pts_init': losses_pts_init,
        'loss_pts_refine': losses_pts_refine
    }
    return loss_dict_all
+
def get_pos_loss(self, cls_score, pts_pred, label, bbox_gt, label_weight,
                 convex_weight, pos_inds):
    """Calculate loss of all potential positive samples obtained from first
    match process.

    Args:
        cls_score (Tensor): Box scores of single image with shape
            (num_anchors, num_classes)
        pts_pred (Tensor): Box energies / deltas of single image
            with shape (num_anchors, 4)
        label (Tensor): classification target of each anchor with
            shape (num_anchors,)
        bbox_gt (Tensor): Ground truth box.
        label_weight (Tensor): Classification loss weight of each
            anchor with shape (num_anchors).
        convex_weight (Tensor): Bbox weight of each anchor with shape
            (num_anchors, 4).
        pos_inds (Tensor): Index of all positive samples got from
            first assign process.

    Returns:
        Tensor: Losses of all positive samples in single image,
        wrapped in a 1-tuple so the result composes with multi_apply.
    """
    # no positives: return an empty loss tensor (still a 1-tuple)
    if pos_inds.size(0) == 0:
        pos_loss = jt.zeros((0))
        return pos_loss,
    pos_scores = cls_score[pos_inds]
    pos_pts_pred = pts_pred[pos_inds]
    pos_bbox_gt = bbox_gt[pos_inds]
    pos_label = label[pos_inds]
    pos_label_weight = label_weight[pos_inds]
    pos_convex_weight = convex_weight[pos_inds]
    # NOTE(review): avg_factor is set to the loss weight while
    # reduction_override='none' — presumably to undo the configured
    # loss scaling for ranking purposes; confirm against the loss impl.
    loss_cls = self.loss_cls(
        pos_scores,
        pos_label,
        pos_label_weight,
        avg_factor=self.loss_cls.loss_weight,
        reduction_override='none')
    loss_bbox = self.loss_bbox_refine(
        pos_pts_pred,
        pos_bbox_gt,
        pos_convex_weight,
        avg_factor=self.loss_cls.loss_weight,
        reduction_override='none')
    # fold the per-class cls loss into one scalar per sample
    loss_cls = loss_cls.sum(-1)
    pos_loss = loss_bbox + loss_cls
    return pos_loss,
+
def get_bboxes(self,
               cls_scores,
               pts_preds_init,
               pts_preds_refine,
               img_metas,
               cfg=None,
               rescale=False,
               with_nms=True,
               **kwargs):
    """Transform network outputs of a batch into bbox results.

    Args:
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        pts_preds_init (list[Tensor]): Init-stage point offsets for all
            scale levels, shape (batch_size, num_points * 2, H, W).
            Unused for decoding; only the refine-stage preds are decoded.
        pts_preds_refine (list[Tensor]): Refine-stage point offsets for
            all scale levels, shape (batch_size, num_points * 2, H, W).
        img_metas (list[dict], Optional): Image meta info. Default None.
        cfg (Config, Optional): Test / postprocessing configuration,
            if None, test_cfg would be used. Default None.
        rescale (bool): If True, return boxes in original image space.
            Default False.
        with_nms (bool): If True, do nms before return boxes.
            Default True.

    Returns:
        list[tuple]: One ``(polys, scores, labels)`` triple per image
        (see ``_get_bboxes_single``): decoded polygon corners, their
        confidence scores in [0, 1], and the predicted class labels.
    """
    assert len(cls_scores) == len(pts_preds_refine)

    num_levels = len(cls_scores)

    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
    # per-level prior point centers, shared by all images in the batch
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes,
        dtype=cls_scores[0].dtype)

    result_list = []

    for img_id, _ in enumerate(img_metas):
        img_meta = img_metas[img_id]
        cls_score_list = select_single_mlvl(cls_scores, img_id)
        point_pred_list = select_single_mlvl(pts_preds_refine, img_id)

        results = self._get_bboxes_single(cls_score_list, point_pred_list,
                                          mlvl_priors, img_meta, cfg,
                                          rescale, with_nms, **kwargs)
        result_list.append(results)

    return result_list
+
def _get_bboxes_single(self,
                       cls_score_list,
                       point_pred_list,
                       mlvl_priors,
                       img_meta,
                       cfg,
                       rescale=False,
                       with_nms=True,
                       **kwargs):
    """Transform outputs of a single image into bbox predictions.

    Args:
        cls_score_list (list[Tensor]): Box scores from all scale
            levels of a single image, each item has shape
            (num_priors * num_classes, H, W).
        point_pred_list (list[Tensor]): Refine-stage point offsets from
            all scale levels of a single image, each item has shape
            (num_points * 2, H, W).
        mlvl_priors (list[Tensor]): Each element in the list is
            the priors of a single level in feature pyramid, has shape
            (num_priors, 2).
        img_meta (dict): Image meta info.
        cfg (Config): Test / postprocessing configuration,
            if None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True. Only True is implemented.

    Returns:
        tuple[Tensor]: ``(polys, scores, det_labels)`` — decoded polygon
        corners of the kept boxes, their scores in [0, 1], and their
        predicted class labels.

    Raises:
        NotImplementedError: If ``with_nms`` is False.
    """

    cfg = self.test_cfg if cfg is None else cfg
    assert len(cls_score_list) == len(point_pred_list)
    scale_factor = img_meta['scale_factor']

    mlvl_bboxes = []
    mlvl_scores = []
    for level_idx, (cls_score, points_pred, points) in enumerate(
            zip(cls_score_list, point_pred_list, mlvl_priors)):
        assert cls_score.size()[-2:] == points_pred.size()[-2:]

        # (C, H, W) -> (H*W, num_classes)
        cls_score = cls_score.permute(1, 2,
                                      0).reshape(-1, self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = cls_score.sigmoid()
        else:
            # softmax path: drop the trailing background column
            scores = cls_score.softmax(-1)[:, :-1]

        points_pred = points_pred.permute(1, 2, 0).reshape(
            -1, 2 * self.num_points)
        # keep only the nms_pre highest-scoring priors per level
        nms_pre = cfg.get('nms_pre', -1)
        if 0 < nms_pre < scores.shape[0]:
            if self.use_sigmoid_cls:
                # NOTE(review): relies on jittor's Var.max(dim) returning
                # the max values (not a (values, indices) pair) — confirm.
                max_scores = scores.max(dim=1)
            else:
                max_scores = scores[:, 1:].max(dim=1)
            _, topk_inds = max_scores.topk(nms_pre)
            points = points[topk_inds, :]
            points_pred = points_pred[topk_inds, :]
            scores = scores[topk_inds, :]

        # decode point offsets into rotated rectangles around each prior
        poly_pred = self.points2rotrect(points_pred, y_first=True)
        bbox_pos_center = points[:, :2].repeat(1, 4)
        polys = poly_pred * self.point_strides[level_idx] + bbox_pos_center
        bboxes = poly2obb(polys)

        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)

    mlvl_bboxes = jt.concat(mlvl_bboxes)

    if rescale:
        # undo the test-time resize (only x/y/w/h; angle is unaffected)
        mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor(
            scale_factor)
    mlvl_scores = jt.concat(mlvl_scores)
    if self.use_sigmoid_cls:
        # prepend a background column — presumably the convention
        # expected by multiclass_nms_rotated; verify against that op
        padding = jt.zeros((mlvl_scores.shape[0], 1), dtype=mlvl_scores.dtype)
        mlvl_scores = jt.concat([padding, mlvl_scores], dim=1)

    if with_nms:
        det_bboxes, det_labels = multiclass_nms_rotated(
            mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms,
            cfg.max_per_img)
        boxes = det_bboxes[:, :5]
        scores = det_bboxes[:, 5]
        polys = rotated_box_to_poly(boxes)
        return polys, scores, det_labels
    else:
        raise NotImplementedError
+
def parse_targets(self, targets):
    """Split raw dataset targets into the keyword arguments expected by
    ``loss`` (training) or ``get_bboxes`` (testing).

    Args:
        targets (list[dict]): Per-image target dicts with keys
            'img_size', 'scale_factor', 'pad_shape' and, in training,
            'rboxes', 'labels', 'rboxes_ignore'.

    Returns:
        dict: kwargs for ``loss``/``get_bboxes``. Only 'img_metas' is
        produced when not training.
    """
    img_metas = [
        dict(
            img_shape=t["img_size"][::-1],
            scale_factor=t["scale_factor"],
            pad_shape=t["pad_shape"],
        )
        for t in targets
    ]
    if not self.is_training():
        return dict(img_metas=img_metas)
    return dict(
        gt_bboxes=[t["rboxes"] for t in targets],
        gt_labels=[t["labels"] for t in targets],
        img_metas=img_metas,
        gt_bboxes_ignore=[t["rboxes_ignore"] for t in targets],
    )
+
def execute(self, feats, targets):
    """Run the head on FPN features.

    Returns the loss dict in training mode, decoded boxes otherwise.
    """
    outs = multi_apply(self.forward_single, feats)
    kwargs = self.parse_targets(targets)
    handler = self.loss if self.is_training() else self.get_bboxes
    return handler(*outs, **kwargs)
+
+
def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Extract the ``batch_id``-th image from every level of a batched
    multi-scale tensor list.

    ``detach`` defaults to True because proposal gradients must be cut
    when training two-stage models (e.g. Cascade Mask R-CNN).

    Args:
        mlvl_tensors (list[Tensor]): Batched 4D tensors, one per level.
        batch_id (int): Index of the image to extract.
        detach (bool): Whether to detach the gradient. Default True.

    Returns:
        list[Tensor]: Single-image tensors, one per scale level.
    """
    assert isinstance(mlvl_tensors, (list, tuple))
    per_level = (level[batch_id] for level in mlvl_tensors)
    if detach:
        return [item.detach() for item in per_level]
    return list(per_level)
+
class MlvlPointGenerator:
    """Standard points generator for multi-level (Mlvl) feature maps in 2D
    points-based detectors.

    Args:
        strides (list[int] | list[tuple[int, int]]): Strides of anchors
            in multiple feature levels in order (w, h).
        offset (float): The offset of points, the value is normalized with
            corresponding stride. Defaults to 0.5.
    """

    def __init__(self, strides, offset=0.5):
        # normalize every stride into a (w, h) pair
        self.strides = [_pair(stride) for stride in strides]
        self.offset = offset

    @property
    def num_levels(self):
        """int: number of feature levels that the generator will be applied"""
        return len(self.strides)

    @property
    def num_base_priors(self):
        """list[int]: The number of priors (points) at a point
        on the feature grid"""
        return [1 for _ in range(len(self.strides))]

    def _meshgrid(self, x, y, row_major=True):
        """Expand 1-D x/y coordinate vectors into flattened grid coords."""
        yy, xx = jt.meshgrid(y, x)
        if row_major:
            # warning .flatten() would cause error in ONNX exporting
            # have to use reshape here
            return xx.reshape(-1), yy.reshape(-1)

        else:
            return yy.reshape(-1), xx.reshape(-1)

    def grid_priors(self,
                    featmap_sizes,
                    dtype=jt.float32,
                    with_stride=False):
        """Generate grid points of multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32.
            with_stride (bool): Whether to concatenate the stride to
                the last dimension of points.

        Return:
            list[jt.Var]: Points of multiple feature levels.
            The sizes of each tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """

        assert self.num_levels == len(featmap_sizes)
        multi_level_priors = []
        for i in range(self.num_levels):
            priors = self.single_level_grid_priors(
                featmap_sizes[i],
                level_idx=i,
                dtype=dtype,
                with_stride=with_stride)
            multi_level_priors.append(priors)
        return multi_level_priors

    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=jt.float32,
                                 with_stride=False):
        """Generate grid Points of a single level.

        Note:
            This function is usually called by method ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Size of the feature maps, arranged
                as (h, w).
            level_idx (int): The index of corresponding feature map level.
            dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32.
            with_stride (bool): Concatenate the stride to the last dimension
                of points.

        Return:
            jt.Var: Points of single feature levels.
            The shape of tensor should be (N, 2) when with stride is
            ``False``, where N = width * height, width and height
            are the sizes of the corresponding feature level,
            and the last dimension 2 represent (coord_x, coord_y),
            otherwise the shape should be (N, 4),
            and the last dimension 4 represent
            (coord_x, coord_y, stride_w, stride_h).
        """
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        # place each point at the (offset-shifted) cell center, in pixels
        shift_x = (jt.arange(0, feat_w) +
                   self.offset) * stride_w
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_x = shift_x.to(dtype)

        shift_y = (jt.arange(0, feat_h) +
                   self.offset) * stride_h
        # keep featmap_size as Tensor instead of int, so that we
        # can convert to ONNX correctly
        shift_y = shift_y.to(dtype)
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        if not with_stride:
            shifts = jt.stack([shift_xx, shift_yy], dim=-1)
        else:
            # use `shape[0]` instead of `len(shift_xx)` for ONNX export
            stride_w = jt.full((shift_xx.shape[0], ), stride_w, dtype=dtype)
            stride_h = jt.full((shift_yy.shape[0], ), stride_h, dtype=dtype)
            shifts = jt.stack([shift_xx, shift_yy, stride_w, stride_h],
                              dim=-1)
        all_points = shifts
        return all_points

    def valid_flags(self, featmap_sizes, pad_shape):
        """Generate valid flags of points of multiple feature levels.

        Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            pad_shape (tuple(int)): The padded shape of the image,
                arranged as (h, w).

        Return:
            list(jt.Var): Valid flags of points of multiple levels.
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):
            point_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]
            # strides are stored (w, h), hence index 1 for height
            valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
            valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w))
            multi_level_flags.append(flags)
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size):
        """Generate the valid flags of points of a single feature map.

        Args:
            featmap_size (tuple[int]): The size of feature maps, arranged
                as (h, w).
            valid_size (tuple[int]): The valid size of the feature maps.
                The size arranged as (h, w).

        Returns:
            jt.Var: The valid flags of each points in a single level \
                feature map.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        # mark positions inside the valid (unpadded) region
        valid_x = jt.zeros(feat_w, dtype=jt.bool)
        valid_y = jt.zeros(feat_h, dtype=jt.bool)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        # a grid point is valid only if both its row and column are valid
        valid = valid_xx & valid_yy
        return valid

    def sparse_priors(self,
                      prior_idxs,
                      featmap_size,
                      level_idx,
                      dtype=jt.float32):
        """Generate sparse points according to the ``prior_idxs``.

        Args:
            prior_idxs (Tensor): The index of corresponding anchors
                in the feature map.
            featmap_size (tuple[int]): feature map size arranged as (h, w).
            level_idx (int): The level index of corresponding feature
                map.
            dtype (obj:`jt.dtype`): Date type of points. Defaults to
                ``jt.float32``.

        Returns:
            Tensor: Anchor with shape (N, 2), N should be equal to
            the length of ``prior_idxs``. And last dimension
            2 represent (coord_x, coord_y).
        """
        height, width = featmap_size
        # flat index -> (col, row), then scale by the level stride
        x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
        y = ((prior_idxs // width) % height +
             self.offset) * self.strides[level_idx][1]
        prioris = jt.stack([x, y], 1).to(dtype)
        return prioris
diff --git a/python/jdet/ops/chamfer_distance.py b/python/jdet/ops/chamfer_distance.py
new file mode 100644
index 0000000..55475de
--- /dev/null
+++ b/python/jdet/ops/chamfer_distance.py
@@ -0,0 +1,240 @@
+# Modified from
+# https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
+
+import jittor as jt
+from jittor import Function
+import numpy as np
+
# Shared CUDA source for the 2D Chamfer distance kernels launched by the
# jt.code wrappers below. The `template <typename scalar_t>` parameter lists
# and the multiple-of-4 mask were restored: the angle-bracketed text had been
# stripped (bare `template` does not compile), and `end_k & (~2)` is not a
# multiple-of-4 mask, so the 4-way unrolled loop could read past the valid
# part of the shared-memory tile.
HEADER = r"""
#define MAX_SHARED_SCALAR_T 6144
#define THREADS_PER_BLOCK 1024

inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 65000;
  return std::min(optimal_block_num, max_block_num);
}

template <typename scalar_t>
__global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
                                                     const scalar_t* xyz, int m,
                                                     const scalar_t* xyz2,
                                                     scalar_t* result,
                                                     int* result_i) {
  // Tile xyz2 through shared memory; each thread scans the tile for its
  // assigned xyz point, tracking the nearest neighbour seen so far.
  __shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
      int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
      for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
        buf[j] = xyz2[(i * m + k2) * 2 + j];
      }
      __syncthreads();
      for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
        scalar_t x1 = xyz[(i * n + j) * 2 + 0];
        scalar_t y1 = xyz[(i * n + j) * 2 + 1];
        int best_i = 0;
        scalar_t best = 1e10;
        // round end_k down to a multiple of 4 for the 4-way unrolled loop
        int end_ka = end_k & (~3);
        if (end_ka == THREADS_PER_BLOCK) {
          for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
#pragma unroll
            for (int j = 0; j < 4; ++j) {
              scalar_t x2 = buf[(k + j) * 2] - x1;
              scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + j;
              }
            }
          }
        } else {
          for (int k = 0; k < end_ka; k += 4) {
#pragma unroll
            for (int j = 0; j < 4; ++j) {
              scalar_t x2 = buf[(k + j) * 2] - x1;
              scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + j;
              }
            }
          }
        }
        // scalar tail for the last (end_k % 4) points of the tile
        for (int k = end_ka; k < end_k; k++) {
          scalar_t x2 = buf[k * 2 + 0] - x1;
          scalar_t y2 = buf[k * 2 + 1] - y1;
          scalar_t d = x2 * x2 + y2 * y2;
          if (k == 0 || d < best) {
            best = d;
            best_i = k + k2;
          }
        }
        // first tile initializes result; later tiles only improve it
        if (k2 == 0 || result[(i * n + j)] > best) {
          result[(i * n + j)] = best;
          result_i[(i * n + j)] = best_i;
        }
      }
      __syncthreads();
    }
  }
}

template <typename scalar_t>
__global__ void chamfer_distance_backward_cuda_kernel(
    int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
    const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
    scalar_t* grad_xyz2) {
  // d/dx of the squared distance to the matched neighbour: 2 * (x1 - x2).
  // Gradients are scattered with atomicAdd because several xyz1 points may
  // share the same nearest xyz2 point.
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
      scalar_t x1 = xyz1[(i * n + j) * 2 + 0];
      scalar_t y1 = xyz1[(i * n + j) * 2 + 1];
      int j2 = idx1[i * n + j];
      scalar_t x2 = xyz2[(i * m + j2) * 2 + 0];
      scalar_t y2 = xyz2[(i * m + j2) * 2 + 1];
      scalar_t g = grad_dist1[i * n + j] * 2;
      atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2));
      atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2)));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2)));
    }
  }
}
"""
+
def chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2):
    """Launch the forward Chamfer-distance CUDA kernels via jt.code.

    Writes the nearest squared distance and the matched index into
    (dist1, idx1) for the xyz1 -> xyz2 direction and into (dist2, idx2)
    for xyz2 -> xyz1.

    Args:
        xyz1 (jt.Var): Point set, shape (B, N, 2), float32.
        xyz2 (jt.Var): Point set, shape (B, M, 2), float32.
        dist1, dist2, idx1, idx2 (jt.Var): Pre-allocated outputs of shape
            (B, N) / (B, M) respectively (see ChamferDistanceFunction).

    Returns:
        list[jt.Var]: [dist1, dist2, idx1, idx2].
    """
    # in0/in1 map to xyz1/xyz2 and out0..out3 to dist1, dist2, idx1, idx2
    # (jt.code's positional naming). The launch configuration had been
    # stripped to `<<>>`; restored as one thread per query point with the
    # kernel instantiated for float, matching the float32 inputs created
    # in ChamferDistanceFunction.execute.
    src = f"""
    const int batch_size = {xyz1.size(0)};
    const int n = {xyz1.size(1)};
    const int m = {xyz2.size(1)};
    chamfer_distance_forward_cuda_kernel<float>
        <<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK>>>(
            batch_size, n, in0_p, m,
            in1_p, out0_p, out2_p);
    chamfer_distance_forward_cuda_kernel<float>
        <<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK>>>(
            batch_size, m, in1_p, n,
            in0_p, out1_p, out3_p);
    """
    return jt.code(
        outputs=[dist1, dist2, idx1, idx2],
        inputs=[xyz1, xyz2],
        cuda_header=HEADER,
        cuda_src=src)
+
def chamfer_distance_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2, grad_xyz1, grad_xyz2):
    """Launch the backward Chamfer-distance CUDA kernels via jt.code.

    Scatters the incoming distance gradients back onto both point sets:
    the first launch handles the xyz1 -> xyz2 direction (n query points),
    the second the xyz2 -> xyz1 direction (m query points).

    Args:
        xyz1, xyz2 (jt.Var): Point sets, shapes (B, N, 2) / (B, M, 2).
        idx1, idx2 (jt.Var): Nearest-neighbour indices from the forward
            pass, shapes (B, N) / (B, M).
        grad_dist1, grad_dist2 (jt.Var): Upstream gradients, (B, N)/(B, M).
        grad_xyz1, grad_xyz2 (jt.Var): Pre-allocated output gradients.

    Returns:
        list[jt.Var]: [grad_xyz1, grad_xyz2].
    """
    # Kernel signature: (b, n, xyz1, m, xyz2, grad_dist1, idx1,
    # grad_xyz1, grad_xyz2). The previous code passed `m` with xyz1 and
    # `n` with xyz2 (and had the launch config stripped to `<<>>`); the
    # swapped point counts were only accidentally correct when n == m.
    src = f"""
    const int batch_size = {xyz1.size(0)};
    const int n = {xyz1.size(1)};
    const int m = {xyz2.size(1)};
    chamfer_distance_backward_cuda_kernel<float>
        <<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK>>>(
            batch_size, n, in0_p, m, in1_p,
            in4_p, in2_p, out0_p, out1_p);
    chamfer_distance_backward_cuda_kernel<float>
        <<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK>>>(
            batch_size, m, in1_p, n, in0_p,
            in5_p, in3_p, out1_p, out0_p);
    """
    return jt.code(
        outputs=[grad_xyz1, grad_xyz2],
        inputs=[xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2],
        cuda_header=HEADER,
        cuda_src=src)
+
+
class ChamferDistanceFunction(Function):
    """This is an implementation of the 2D Chamfer Distance.

    It has been used in the paper `Oriented RepPoints for Aerial Object
    Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>`_.
    """

    def execute(self, xyz1, xyz2):
        """
        Args:
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

                - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
                    shape (B, N).
                - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
                    shape (B, M).
                - idx1 (Tensor): Index of the nearest xyz2 point for each
                    xyz1 point, shape (B, N); used to compute the gradient.
                - idx2 (Tensor): Index of the nearest xyz1 point for each
                    xyz2 point, shape (B, M); used to compute the gradient.
        """
        batch_size, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # per-direction squared distances and nearest-neighbour indices,
        # filled in-place by the CUDA kernels
        dist1 = jt.zeros(batch_size, n)
        dist2 = jt.zeros(batch_size, m)
        idx1 = jt.zeros((batch_size, n), dtype=jt.int32)
        idx2 = jt.zeros((batch_size, m), dtype=jt.int32)

        chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        # keep what grad() needs to scatter the gradients back
        self.save_for_backward = xyz1, xyz2, idx1, idx2
        return dist1, dist2, idx1, idx2

    def grad(self,
             grad_dist1,
             grad_dist2,
             grad_idx1=None,
             grad_idx2=None):
        """
        Args:
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1, grad_idx2: Ignored; the index outputs are
                non-differentiable.

        Returns:
            Tuple[Tensor, Tensor]:

                - grad_xyz1 (Tensor): Gradient of the point set with shape \
                    (B, N, 2).
                - grad_xyz2 (Tensor): Gradient of the point set with shape \
                    (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = self.save_for_backward
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()
        # output gradients must start at zero: the kernels accumulate
        # with atomicAdd
        grad_xyz1 = jt.zeros(xyz1.size())
        grad_xyz2 = jt.zeros(xyz2.size())

        chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                  grad_dist1, grad_dist2, grad_xyz1,
                                  grad_xyz2)
        return grad_xyz1, grad_xyz2
+
+
# Functional alias: chamfer_distance(xyz1, xyz2) -> (dist1, dist2, idx1, idx2)
chamfer_distance = ChamferDistanceFunction.apply
+
def test():
    """Smoke-test chamfer_distance against hand-computed expectations.

    Uses three pairs of 4-point sets and checks both distance directions
    to 1e-2 relative tolerance. Requires CUDA (the op is CUDA-only).
    """
    pointset1 = jt.array(
        [[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
         [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
         [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]])

    pointset2 = jt.array(
        [[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
         [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
         [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]])

    # squared distances to the nearest point in the other set
    expected_dist1 = jt.array(
        [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
         [0.5200, 0.6500, 0.4900, 0.3600]])
    expected_dist2 = jt.array(
        [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
         [0.7200, 0.8500, 0.4900, 0.3600]])

    dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2)
    np.testing.assert_allclose(dist1.numpy(),expected_dist1.numpy(),rtol=1e-2)
    np.testing.assert_allclose(dist2.numpy(),expected_dist2.numpy(),rtol=1e-2)
+
if __name__ == "__main__":
    # the op is implemented only as a CUDA kernel, so CUDA must be enabled
    jt.flags.use_cuda=1
    test()