diff --git a/README.md b/README.md index 2f3738c..643d01f 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,9 @@ python run_net.py --config-file=configs/base.py --task=test | RSDet-R50-FPN | DOTA1.0|1024/200|Flip|-| SGD | 1x | 68.41 | [arxiv](https://arxiv.org/abs/1911.08299) | [config](configs/rotated_retinanet/rsdet_obb_r50_fpn_1x_dota_lmr5p.py) | [model](https://cloud.tsinghua.edu.cn/f/642e200f5a8a420eb726/?dl=1) | | ATSS-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 72.44 | [arxiv](https://arxiv.org/abs/1912.02424) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/5168189dcd364eaebce5/?dl=1) | | Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 56.34 | [arxiv](https://arxiv.org/abs/1904.11490) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/be359ac932c84f9c839e/?dl=1) | +| CFA-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | | [arxiv]() | [config]() | [model]() | +| Oriented-Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | | [arxiv]() | [config]() | [model]() | +| SASM-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | | [arxiv]() | [config]() | [model]() | **Notice**: @@ -153,9 +156,11 @@ python run_net.py --config-file=configs/base.py --task=test - :heavy_check_mark: Reppoints - :heavy_check_mark: RSDet - :heavy_check_mark: ATSS +- :heavy_check_mark: CFA +- :heavy_check_mark: Oriented Reppoints +- :heavy_check_mark: SASM - :clock3: R3Det - :clock3: Cascade R-CNN -- :clock3: Oriented Reppoints - :heavy_plus_sign: DCL - :heavy_plus_sign: Double Head OBB - :heavy_plus_sign: Guided Anchoring diff --git a/configs/cfa_r50_fpn_1x_dota.py b/configs/cfa_r50_fpn_1x_dota.py new file mode 100644 index 0000000..4f341b4 --- /dev/null +++ b/configs/cfa_r50_fpn_1x_dota.py @@ -0,0 +1,161 @@ +# model settings +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='Resnet50', + frozen_stages=1, + 
return_stages=["layer1","layer2","layer3","layer4"], + pretrained= True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs="on_input", + norm_cfg = dict(type='GN', num_groups=32), + num_outs=5), + bbox_head=dict( + type='RotatedRepPointsHead', + num_classes=15, + in_channels=256, + feat_channels=256, + point_feat_channels=256, + stacked_convs=3, + num_points=9, + gradient_mul=0.3, + point_strides=[8, 16, 32, 64, 128], + point_base_scale=2, + norm_cfg = dict(type='GN', num_groups=32), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375), + loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0), + transform_method='rotrect', + use_reassign=True, + topk=6, + anti_factor=0.75, + train_cfg=dict( + init=dict( + assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + refine=dict( + assigner=dict( + type='MaxConvexIoUAssigner', + pos_iou_thr=0.1, + neg_iou_thr=0.1, + min_pos_iou=0, + ignore_iof_thr=-1, + assigned_labels_filled=-1, + iou_calculator=dict(type='ConvexOverlaps')), + allowed_border=-1, + pos_weight=-1, + debug=False)), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.4), + max_per_img=2000)) + ) +dataset = dict( + train=dict( + type="DOTADataset", + dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"), + dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,) + + ], + batch_size=1, + num_workers=4, + 
shuffle=True, + filter_empty_gt=False + ), + val=dict( + type="DOTADataset", + dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + num_workers=4, + shuffle=False + ), + test=dict( + type="ImageDataset", + images_dir='/home/zytx121/mmrotate/data/processed_DOTA/test_1024_200_1.0/images', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,), + ], + num_workers=4, + batch_size=1, + ) +) + +optimizer = dict( + type='SGD', + lr=0.008, + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( + max_norm=35, + norm_type=2)) + +scheduler = dict( + type='StepLR', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + milestones=[7, 10]) + + +logger = dict( + type="RunLogger") + +max_epoch = 12 +eval_interval = 1 +checkpoint_interval = 1 +log_interval = 50 diff --git a/configs/sasm_obb_r50_fpn_1x_dota.py b/configs/sasm_obb_r50_fpn_1x_dota.py new file mode 100644 index 0000000..5397876 --- /dev/null +++ b/configs/sasm_obb_r50_fpn_1x_dota.py @@ -0,0 +1,155 @@ +# model settings +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='Resnet50', + frozen_stages=1, + return_stages=["layer1","layer2","layer3","layer4"], + pretrained= True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs="on_input", + norm_cfg = dict(type='GN', num_groups=32), + num_outs=5), + bbox_head=dict( + type='SAMRepPointsHead', + num_classes=15, + in_channels=256, + feat_channels=256, + point_feat_channels=256, + stacked_convs=3, + 
num_points=9, + gradient_mul=0.3, + point_strides=[8, 16, 32, 64, 128], + point_base_scale=2, + norm_cfg = dict(type='GN', num_groups=32), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375), + loss_bbox_refine=dict(type='BCConvexGIoULoss', loss_weight=1.0), + transform_method='rotrect', + use_reassign=False, + topk=6, + anti_factor=0.75, + train_cfg=dict( + init=dict( + assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + refine=dict( + assigner=dict(type='SASAssigner', topk=9), + allowed_border=-1, + pos_weight=-1, + debug=False)), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.4), + max_per_img=2000)) + ) +dataset = dict( + train=dict( + type="DOTADataset", + dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"), + dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,) + + ], + batch_size=2, + num_workers=4, + shuffle=True, + filter_empty_gt=False + ), + val=dict( + type="DOTADataset", + dataset_dir='/home/zytx121/mmrotate/data/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + num_workers=4, + shuffle=False + ), + test=dict( + type="ImageDataset", + 
images_dir='/home/zytx121/mmrotate/data/processed_DOTA/test_1024_200_1.0/images', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,), + ], + num_workers=4, + batch_size=1, + ) +) + +optimizer = dict( + type='SGD', + # lr=0.01/4., #0.0,#0.01*(1/8.), + lr=0.008, + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( + max_norm=35, + norm_type=2)) + +scheduler = dict( + type='StepLR', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + milestones=[7, 10]) + + +logger = dict( + type="RunLogger") + +max_epoch = 12 +eval_interval = 1 +checkpoint_interval = 1 +log_interval = 50 diff --git a/projects/oriented_reppoints/README.md b/projects/oriented_reppoints/README.md new file mode 100644 index 0000000..22e4b01 --- /dev/null +++ b/projects/oriented_reppoints/README.md @@ -0,0 +1,45 @@ +## Oriented RepPoints +> [Oriented RepPoints for Aerial Object Detection](https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Oriented_RepPoints_for_Aerial_Object_Detection_CVPR_2022_paper.pdf) + + +### Abstract + +
+ +
+ +In contrast to the generic object, aerial targets are often non-axis aligned with arbitrary orientations having +the cluttered surroundings. Unlike the mainstreamed approaches regressing the bounding box orientations, this paper +proposes an effective adaptive points learning approach to aerial object detection by taking advantage of the adaptive +points representation, which is able to capture the geometric information of the arbitrary-oriented instances. +To this end, three oriented conversion functions are presented to facilitate the classification and localization +with accurate orientation. Moreover, we propose an effective quality assessment and sample assignment scheme for +adaptive points learning toward choosing the representative oriented reppoints samples during training, which is +able to capture the non-axis aligned features from adjacent objects or background noises. A spatial constraint is +introduced to penalize the outlier points for robust adaptive learning. Experimental results on four challenging +aerial datasets including DOTA, HRSC2016, UCAS-AOD and DIOR-R, demonstrate the efficacy of our proposed approach. 
+ +### Training +```sh +python run_net.py --config-file=configs/oriented_reppoints_r50_fpn_1x_dota.py --task=train +``` + +### Testing +```sh +python run_net.py --config-file=configs/oriented_reppoints_r50_fpn_1x_dota.py --task=test +``` + +### Performance +| Models | Dataset| Sub_Image_Size/Overlap |Train Aug | Test Aug | Optim | Lr schd | mAP | Paper | Config | Download | +|:-----------:| :-----: |:-----:|:-----:| :-----: | :-----:| :-----:| :----: |:--------:|:--------------------------------------------------------------:| :--------: | +| Oriented-Reppoints-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 66.99 | [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Oriented_RepPoints_for_Aerial_Object_Detection_CVPR_2022_paper.pdf)| [config](configs/oriented_reppoints_r50_fpn_1x_dota.py) | [model]() | + +### Citation +``` +@inproceedings{li2022ori, + title={Oriented RepPoints for Aerial Object Detection}, + author={Wentong Li and Yijie Chen and Kaixuan Hu and Jianke Zhu}, + booktitle={Proceedings of IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2022} +} +``` \ No newline at end of file diff --git a/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py b/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py new file mode 100644 index 0000000..8a620c1 --- /dev/null +++ b/projects/oriented_reppoints/configs/oriented_reppoints_r50_fpn_1x_dota.py @@ -0,0 +1,165 @@ +# model settings +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='Resnet50', + frozen_stages=1, + return_stages=["layer1","layer2","layer3","layer4"], + pretrained= True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs="on_input", + norm_cfg = dict(type='GN', num_groups=32), + num_outs=5), + bbox_head=dict( + type='OrientedRepPointsHead', + num_classes=15, + in_channels=256, + feat_channels=256, + point_feat_channels=256, + stacked_convs=3, + 
num_points=9, + gradient_mul=0.3, + point_strides=[8, 16, 32, 64, 128], + point_base_scale=2, + norm_cfg = dict(type='GN', num_groups=32), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375), + loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0), + loss_spatial_init=dict(type='SpatialBorderLoss', loss_weight=0.05), + loss_spatial_refine=dict(type='SpatialBorderLoss', loss_weight=0.1), + transform_method='rotrect', + use_reassign=False, + init_qua_weight=0.2, + top_ratio=0.4, + # topk=6, + # anti_factor=0.75, + train_cfg=dict( + init=dict( + assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + refine=dict( + assigner=dict( + type='MaxConvexIoUAssigner', + pos_iou_thr=0.1, + neg_iou_thr=0.1, + min_pos_iou=0, + ignore_iof_thr=-1, + assigned_labels_filled=-1, + iou_calculator=dict(type='ConvexOverlaps')), + allowed_border=-1, + pos_weight=-1, + debug=False)), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.4), + max_per_img=2000)) + ) +dataset = dict( + train=dict( + type="DOTADataset", + dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"), + dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,) + + ], + batch_size=2, + num_workers=4, + shuffle=True, + filter_empty_gt=False + ), + val=dict( + type="DOTADataset", + dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + 
min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + num_workers=4, + shuffle=False + ), + test=dict( + type="ImageDataset", + images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,), + ], + num_workers=4, + batch_size=1, + ) +) + +optimizer = dict( + type='SGD', + lr=0.008, + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( + max_norm=35, + norm_type=2)) + +scheduler = dict( + type='StepLR', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + milestones=[7, 10]) + + +logger = dict( + type="RunLogger") + +max_epoch = 12 +eval_interval = 1 +checkpoint_interval = 1 +log_interval = 50 diff --git a/projects/oriented_reppoints/configs/oriented_reppoints_test.py b/projects/oriented_reppoints/configs/oriented_reppoints_test.py new file mode 100644 index 0000000..b8094e8 --- /dev/null +++ b/projects/oriented_reppoints/configs/oriented_reppoints_test.py @@ -0,0 +1,162 @@ +# model settings +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='Resnet50', + frozen_stages=1, + return_stages=["layer1","layer2","layer3","layer4"], + pretrained= True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs="on_input", + norm_cfg = dict(type='GN', num_groups=32), + num_outs=5), + bbox_head=dict( + type='RotatedRepPointsHead', + num_classes=15, + in_channels=256, + feat_channels=256, + point_feat_channels=256, + stacked_convs=3, + num_points=9, + gradient_mul=0.3, + point_strides=[8, 16, 32, 64, 128], + point_base_scale=2, + norm_cfg = 
dict(type='GN', num_groups=32), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict(type='ConvexGIoULoss', loss_weight=0.375), + loss_bbox_refine=dict(type='ConvexGIoULoss', loss_weight=1.0), + transform_method='rotrect', + use_reassign=False, + topk=6, + anti_factor=0.75, + train_cfg=dict( + init=dict( + assigner=dict(type='ConvexAssigner', scale=4, pos_num=1, assigned_labels_filled=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + refine=dict( + assigner=dict( + type='MaxConvexIoUAssigner', + pos_iou_thr=0.4, + neg_iou_thr=0.3, + min_pos_iou=0, + ignore_iof_thr=-1, + assigned_labels_filled=-1, + iou_calculator=dict(type='ConvexOverlaps')), + allowed_border=-1, + pos_weight=-1, + debug=False)), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.4), + max_per_img=2000)) + ) +dataset = dict( + train=dict( + type="DOTADataset", + dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + # dict(type='RotatedRandomFlip', prob=0.5, direction="horizontal"), + # dict(type='RotatedRandomFlip', prob=0.5, direction="vertical"), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,) + + ], + batch_size=2, + num_workers=4, + shuffle=True, + filter_empty_gt=False + ), + val=dict( + type="DOTADataset", + dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False), + ], + batch_size=2, + num_workers=4, + shuffle=False + ), + test=dict( + type="ImageDataset", + 
images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images', + transforms=[ + dict( + type="RotatedResize", + min_size=1024, + max_size=1024 + ), + dict( + type = "Pad", + size_divisor=32), + dict( + type = "Normalize", + mean = [123.675, 116.28, 103.53], + std = [58.395, 57.12, 57.375], + to_bgr=False,), + ], + num_workers=4, + batch_size=1, + ) +) + +optimizer = dict( + type='SGD', + # lr=0.01/4., #0.0,#0.01*(1/8.), + lr=0.008, + momentum=0.9, + weight_decay=0.0001, + grad_clip=dict( + max_norm=35, + norm_type=2)) + +scheduler = dict( + type='StepLR', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + milestones=[7, 10]) + + +logger = dict( + type="RunLogger") + +max_epoch = 12 +eval_interval = 1 +checkpoint_interval = 1 +log_interval = 50 diff --git a/projects/oriented_reppoints/run_net.py b/projects/oriented_reppoints/run_net.py new file mode 100644 index 0000000..25f1d65 --- /dev/null +++ b/projects/oriented_reppoints/run_net.py @@ -0,0 +1,56 @@ +import argparse +import jittor as jt +from jdet.runner import Runner +from jdet.config import init_cfg + + +def main(): + parser = argparse.ArgumentParser(description="Jittor Object Detection Training") + parser.add_argument( + "--config-file", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "--task", + default="train", + help="train,val,test", + type=str, + ) + + parser.add_argument( + "--no_cuda", + action='store_true' + ) + + parser.add_argument( + "--save_dir", + default=".", + type=str, + ) + + args = parser.parse_args() + + if not args.no_cuda: + jt.flags.use_cuda=1 + + assert args.task in ["train","val","test","vis_test"],f"{args.task} not support, please choose [train,val,test,vis_test]" + + if args.config_file: + init_cfg(args.config_file) + + runner = Runner() + + if args.task == "train": + runner.run() + elif args.task == "val": + runner.val() + elif args.task == "test": + runner.test() + elif args.task == 
"vis_test": + runner.run_on_images(args.save_dir) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/projects/oriented_reppoints/test_oriented_reppoints.py b/projects/oriented_reppoints/test_oriented_reppoints.py new file mode 100644 index 0000000..b2b34b8 --- /dev/null +++ b/projects/oriented_reppoints/test_oriented_reppoints.py @@ -0,0 +1,72 @@ +import jittor as jt +from jdet.config import init_cfg, get_cfg +from jdet.utils.general import parse_losses +from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS +import argparse +import os +import pickle as pk +import jdet + +def main(): + parser = argparse.ArgumentParser(description="Jittor Object Detection Training") + parser.add_argument( + "--set_data", + action='store_true' + ) + args = parser.parse_args() + + jt.flags.use_cuda=1 + jt.set_global_seed(666) + init_cfg("projects/oriented_reppoints/configs/oriented_reppoints_test.py") + cfg = get_cfg() + + model = build_from_cfg(cfg.model,MODELS) + optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params= model.parameters()) + + model.train() + if (args.set_data): + imagess = [] + targetss = [] + correct_loss = [] + train_dataset = build_from_cfg(cfg.dataset.train,DATASETS) + for batch_idx,(images,targets) in enumerate(train_dataset): + if (batch_idx > 10): + break + print(batch_idx) + imagess.append(jdet.utils.general.sync(images)) + targetss.append(jdet.utils.general.sync(targets)) + losses = model(images,targets) + all_loss,losses = parse_losses(losses) + optimizer.step(all_loss) + correct_loss.append(all_loss.item()) + data = { + "imagess": imagess, + "targetss": targetss, + "correct_loss": correct_loss, + } + if (not os.path.exists("test_datas_oriented_reppoints")): + os.makedirs("test_datas_oriented_reppoints") + pk.dump(data, open("test_datas_oriented_reppoints/test_data.pk", "wb")) + print(correct_loss) + else: + data = pk.load(open("test_datas_oriented_reppoints/test_data.pk", "rb")) + imagess = 
jdet.utils.general.to_jt_var(data["imagess"]) + targetss = jdet.utils.general.to_jt_var(data["targetss"]) + correct_loss = data["correct_loss"] + + for batch_idx in range(len(imagess)): + images = imagess[batch_idx] + targets = targetss[batch_idx] + losses = model(images,targets) + all_loss,losses = parse_losses(losses) + optimizer.step(all_loss) + l = all_loss.item() + c_l = correct_loss[batch_idx] + err_rate = abs(c_l-l)/min(c_l,l) + print(f"correct loss is {c_l:.4f}, runtime loss is {l:.4f}, err rate is {err_rate*100:.2f}%") + assert err_rate<1e-3,"LOSS is not correct, please check it" + print(f"Loss is correct with err_rate<{1e-3}") + print("success!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/jdet/models/boxes/assigner.py b/python/jdet/models/boxes/assigner.py index 253aeac..c777534 100644 --- a/python/jdet/models/boxes/assigner.py +++ b/python/jdet/models/boxes/assigner.py @@ -1,6 +1,7 @@ import jittor as jt from jdet.utils.registry import BOXES,build_from_cfg from jdet.models.boxes.box_ops import points_in_rotated_boxes +from jdet.ops.reppoints_convex_iou import reppoints_convex_iou import numpy as np def deleteme(a, b, size = 10): @@ -414,10 +415,10 @@ def get_horizontal_bboxes(self, gt_rbboxes): """get_horizontal_bboxes from polygons. Args: - gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8). + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). Returns: - gt_rect_bboxes (torch.Tensor): The horizontal bboxes, shape (k, 4). + gt_rect_bboxes (jt.Tensor): The horizontal bboxes, shape (k, 4). """ gt_xs, gt_ys = gt_rbboxes[:, 0::2], gt_rbboxes[:, 1::2] gt_xmin = gt_xs.min(1) @@ -451,8 +452,8 @@ def assign(self, 6. limit the positive sample's center in gt Args: - points (torch.Tensor): Points to be assigned, shape(n, 18). - gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8). + points (jt.Tensor): Points to be assigned, shape(n, 18). 
+ gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). gt_rbboxes_ignore (Tensor, optional): Ground truth polygons that are labelled as `ignored`, e.g., crowd boxes in COCO. gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). @@ -465,15 +466,17 @@ def assign(self, if num_gts == 0 or num_points == 0: # If no truth assign everything to the background - assigned_gt_inds = points.new_full((num_points, ), - 0, - dtype=jt.int32) + assigned_gt_inds = jt.full((num_points, ), 0, dtype=jt.int32) + # assigned_gt_inds = points.new_full((num_points, ), + # 0, + # dtype=jt.int32) if gt_labels is None: assigned_labels = None else: - assigned_labels = points.new_full((num_points, ), - -1, - dtype=jt.int32) + assigned_labels = jt.full((num_points, ), -1, dtype=jt.int32) + # assigned_labels = points.new_full((num_points, ), + # -1, + # dtype=jt.int32) return AssignResult( num_gts, assigned_gt_inds, None, labels=assigned_labels) @@ -588,9 +591,9 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): 6. limit the positive sample's center in gt Args: - points (torch.Tensor): Points to be assigned, shape(n, 18). - gt_rbboxes (torch.Tensor): Groundtruth polygons, shape (k, 8). - overlaps (torch.Tensor): Overlaps between k gt_bboxes and n bboxes, + points (jt.Tensor): Points to be assigned, shape(n, 18). + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). + overlaps (jt.Tensor): Overlaps between k gt_bboxes and n bboxes, shape(k, n). gt_rbboxes_ignore (Tensor, optional): Ground truth polygons that are labelled as `ignored`, e.g., crowd boxes in COCO. @@ -599,6 +602,8 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): Returns: :obj:`AssignResult`: The assign result. 
""" + + if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0: raise ValueError('No gt or bboxes') @@ -609,3 +614,217 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): assert NotImplementedError assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) return assign_result + +@BOXES.register_module() +class SASAssigner: + """Assign a corresponding gt bbox or background to each bbox. Each + proposals will be assigned with `0` or a positive integer indicating the + ground truth index. + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + scale (float): IoU threshold for positive bboxes. + pos_num (float): find the nearest pos_num points to gt center in this + level. + """ + + def __init__(self, topk): + self.topk = topk + + def assign(self, + bboxes, + num_level_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + gt_labels=None): + """Assign gt to bboxes. + + The assignment is done in following steps + + 1. compute iou between all bbox (bbox of all pyramid levels) and gt + 2. compute center distance between all bbox and gt + 3. on each pyramid level, for each gt, select k bbox whose center + are closest to the gt center, so we total select k*l bbox as + candidates for each gt + 4. get corresponding iou for the these candidates, and compute the + mean and std, set mean + std as the iou threshold + 5. select these candidates whose iou are greater than or equal to + the threshold as positive + 6. limit the positive sample's center in gt + + Args: + bboxes (jt.Tensor): Bounding boxes to be assigned, shape(n, 4). + num_level_bboxes (List): num of bboxes in each level + gt_bboxes (jt.Tensor): Groundtruth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. 
+ """ + INF = 100000000 + + num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0) + overlaps = convex_overlaps(gt_bboxes, bboxes) + # assigned_gt_inds = overlaps.new_full((num_bboxes, ), + # 0, + # dtype=jt.long) + assigned_gt_inds = jt.full((num_bboxes, ), 0, dtype=jt.int32) + + if num_gt == 0 or num_bboxes == 0: + max_overlaps = overlaps.new_zeros((num_bboxes, )) + if num_gt == 0: + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = jt.full((num_bboxes, ), -1, dtype=jt.int32) + # assigned_labels = overlaps.new_full((num_bboxes, ), + # -1, + # dtype=jt.long) + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + # compute center distance between all bbox and gt + # the center of poly + gt_bboxes_hbb = get_horizontal_bboxes(gt_bboxes) + + gt_cx = (gt_bboxes_hbb[:, 0] + gt_bboxes_hbb[:, 2]) / 2.0 + gt_cy = (gt_bboxes_hbb[:, 1] + gt_bboxes_hbb[:, 3]) / 2.0 + gt_points = jt.stack((gt_cx, gt_cy), dim=1) + + bboxes = bboxes.reshape(-1, 9, 2) + pts_x = bboxes[:, :, 0::2] + pts_y = bboxes[:, :, 1::2] + + pts_x_mean = pts_x.mean(dim=1).squeeze() + pts_y_mean = pts_y.mean(dim=1).squeeze() + + bboxes_points = jt.stack((pts_x_mean, pts_y_mean), dim=1) + + distances = (bboxes_points[:, None, :] - + gt_points[None, :, :]).pow(2).sum(-1).sqrt() + + # Selecting candidates based on the center distance + candidate_idxs = [] + start_idx = 0 + for level, bboxes_per_level in enumerate(num_level_bboxes): + end_idx = start_idx + bboxes_per_level + distances_per_level = distances[start_idx:end_idx, :] + _, topk_idxs_per_level = distances_per_level.topk( + self.topk, dim=0, largest=False) + candidate_idxs.append(topk_idxs_per_level + start_idx) + start_idx = end_idx + candidate_idxs = jt.contrib.concat(candidate_idxs, dim=0) + + gt_bboxes_ratios = AspectRatio(gt_bboxes) + gt_bboxes_ratios_per_gt = gt_bboxes_ratios.mean(0) + candidate_overlaps = overlaps[candidate_idxs, jt.arange(num_gt)] + 
overlaps_mean_per_gt = candidate_overlaps.mean(0) + overlaps_std_per_gt = candidate_overlaps.std() + overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt + + # new assign + iou_thr_weight = jt.exp((-1 / 4) * gt_bboxes_ratios_per_gt) + overlaps_thr_per_gt = overlaps_thr_per_gt * iou_thr_weight + is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :] + + # limit the positive sample's center in gt + # inside_flag = jt.full([num_bboxes, num_gt], + # 0.).to(gt_bboxes.device).float() + # inside_flag = points_in_polygons(bboxes_points, gt_bboxes) + inside_flag = points_in_rotated_boxes(bboxes_points, gt_bboxes) + # pointsJf(bboxes_points, gt_bboxes, inside_flag) + is_in_gts = inside_flag[candidate_idxs, + jt.arange(num_gt)].to(is_pos.dtype) + + is_pos = is_pos & is_in_gts + for gt_idx in range(num_gt): + candidate_idxs[:, gt_idx] += gt_idx * num_bboxes + candidate_idxs = candidate_idxs.view(-1) + + # if an anchor box is assigned to multiple gts, + # the one with the highest IoU will be selected. + overlaps_inf = jt.full_like(overlaps, + -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + + overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] + overlaps_inf = overlaps_inf.view(num_gt, -1).t() + + argmax_overlaps, max_overlaps = overlaps_inf.argmax(dim=1) + assigned_gt_inds[ + max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 + if gt_labels is not None and bool(assigned_gt_inds.sum()): + # assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + assigned_labels = jt.full((num_bboxes, ), -1, dtype=jt.int32) + pos_inds = jt.nonzero( + assigned_gt_inds > 0).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + return AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + +def AspectRatio(gt_rbboxes): + """Compute the aspect ratio of all gts. 
+ + Args: + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). + + Returns: + ratios (jt.Tensor): The aspect ratio of gt_rbboxes, shape (k, 1). + """ + pt1, pt2, pt3, pt4 = gt_rbboxes[..., :8].chunk(4, 1) + edge1 = jt.sqrt( + jt.pow(pt1[..., 0] - pt2[..., 0], 2) + + jt.pow(pt1[..., 1] - pt2[..., 1], 2)) + edge2 = jt.sqrt( + jt.pow(pt2[..., 0] - pt3[..., 0], 2) + + jt.pow(pt2[..., 1] - pt3[..., 1], 2)) + edges = jt.stack([edge1, edge2], dim=1) + width = jt.max(edges, 1) + height = jt.min(edges, 1) + ratios = (width / height) + return ratios + +def get_horizontal_bboxes(gt_rbboxes): + """Get horizontal bboxes from polygons. + + Args: + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). + + Returns: + gt_rect_bboxes (jt.Tensor): The horizontal bboxes, shape (k, 4). + """ + gt_xs, gt_ys = gt_rbboxes[:, 0::2], gt_rbboxes[:, 1::2] + gt_xmin = gt_xs.min(1) + gt_ymin = gt_ys.min(1) + gt_xmax = gt_xs.max(1) + gt_ymax = gt_ys.max(1) + gt_rect_bboxes = jt.contrib.concat([ + gt_xmin[:, None], gt_ymin[:, None], gt_xmax[:, None], gt_ymax[:, None] + ], + dim=1) + return gt_rect_bboxes + +def convex_overlaps(gt_rbboxes, points): + """Compute overlaps between polygons and points. + + Args: + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). + points (jt.Tensor): Points to be assigned, shape(n, 18). + + Returns: + overlaps (jt.Tensor): Overlaps between k gt_bboxes and n bboxes, + shape(k, n). 
+ """ + if gt_rbboxes.shape[0] == 0: + return gt_rbboxes.new_zeros((0, points.shape[0])) + overlaps = reppoints_convex_iou(points, gt_rbboxes) + return overlaps diff --git a/python/jdet/models/losses/__init__.py b/python/jdet/models/losses/__init__.py index fb153d1..b0d3fe6 100644 --- a/python/jdet/models/losses/__init__.py +++ b/python/jdet/models/losses/__init__.py @@ -13,4 +13,5 @@ from .kd_loss import IMLoss from .rsdet_loss import RSDetLoss from .ridet_loss import RIDetLoss -from .convex_giou_loss import ConvexGIoULoss +from .convex_giou_loss import ConvexGIoULoss, BCConvexGIoULoss +from .spatial_border_loss import SpatialBorderLoss \ No newline at end of file diff --git a/python/jdet/models/losses/convex_giou_loss.py b/python/jdet/models/losses/convex_giou_loss.py index 35218b5..0d9dcc6 100644 --- a/python/jdet/models/losses/convex_giou_loss.py +++ b/python/jdet/models/losses/convex_giou_loss.py @@ -18,9 +18,9 @@ def execute(self, Args: ctx: {save_for_backward, convex_points_grad} - pred (torch.Tensor): Predicted convexes. - target (torch.Tensor): Corresponding gt convexes. - weight (torch.Tensor, optional): The weight of loss for each + pred (jt.Tensor): Predicted convexes. + target (jt.Tensor): Corresponding gt convexes. + weight (jt.Tensor, optional): The weight of loss for each prediction. Defaults to None. reduction (str, optional): The reduction method of the loss. Defaults to None. @@ -72,7 +72,7 @@ class ConvexGIoULoss(nn.Module): loss_weight (float, optional): The weight of loss. Defaults to 1.0. Return: - torch.Tensor: Loss tensor. + jt.Tensor: Loss tensor. """ def __init__(self, reduction='mean', loss_weight=1.0): @@ -90,9 +90,9 @@ def execute(self, """Forward function. Args: - pred (torch.Tensor): Predicted convexes. - target (torch.Tensor): Corresponding gt convexes. - weight (torch.Tensor, optional): The weight of loss for each + pred (jt.Tensor): Predicted convexes. + target (jt.Tensor): Corresponding gt convexes. 
+ weight (jt.Tensor, optional): The weight of loss for each prediction. Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. @@ -109,3 +109,220 @@ def execute(self, pred, target, weight, reduction, avg_factor, self.loss_weight) return loss + +@LOSSES.register_module() +class BCConvexGIoULossFuction(jt.Function): + """The function of BCConvex GIoU loss.""" + + + def execute(self, + pred, + target, + weight=None, + reduction=None, + avg_factor=None, + loss_weight=1.0): + """Forward function. + + Args: + ctx: {save_for_backward, convex_points_grad} + pred (jt.Tensor): Predicted convexes. + target (jt.Tensor): Corresponding gt convexes. + weight (jt.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + reduction (str, optional): The reduction method of the + loss. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + loss_weight (float, optional): The weight of loss. Defaults to 1.0. 
+ """ + convex_gious, grad = reppoints_convex_giou(pred, target) + + pts_pred_all_dx = pred[:, 0::2] + pts_pred_all_dy = pred[:, 1::2] + pred_left_x_inds = pts_pred_all_dx.argmin(dim=1)[0].unsqueeze(0) + pred_right_x_inds = pts_pred_all_dx.argmax(dim=1)[0].unsqueeze(0) + pred_up_y_inds = pts_pred_all_dy.argmin(dim=1)[0].unsqueeze(0) + pred_bottom_y_inds = pts_pred_all_dy.argmax(dim=1)[0].unsqueeze(0) + + pred_right_x = pts_pred_all_dx.gather(dim=1, index=pred_right_x_inds) + pred_right_y = pts_pred_all_dy.gather(dim=1, index=pred_right_x_inds) + + pred_left_x = pts_pred_all_dx.gather(dim=1, index=pred_left_x_inds) + pred_left_y = pts_pred_all_dy.gather(dim=1, index=pred_left_x_inds) + + pred_up_x = pts_pred_all_dx.gather(dim=1, index=pred_up_y_inds) + pred_up_y = pts_pred_all_dy.gather(dim=1, index=pred_up_y_inds) + + pred_bottom_x = pts_pred_all_dx.gather(dim=1, index=pred_bottom_y_inds) + pred_bottom_y = pts_pred_all_dy.gather(dim=1, index=pred_bottom_y_inds) + pred_corners = jt.concat([ + pred_left_x, pred_left_y, pred_up_x, pred_up_y, pred_right_x, + pred_right_y, pred_bottom_x, pred_bottom_y + ], + dim=-1) + + pts_target_all_dx = target[:, 0::2] + pts_target_all_dy = target[:, 1::2] + + target_left_x_inds = pts_target_all_dx.argmin(dim=1)[0].unsqueeze(0) + target_right_x_inds = pts_target_all_dx.argmax(dim=1)[0].unsqueeze(0) + target_up_y_inds = pts_target_all_dy.argmin(dim=1)[0].unsqueeze(0) + target_bottom_y_inds = pts_target_all_dy.argmax(dim=1)[0].unsqueeze(0) + + target_right_x = pts_target_all_dx.gather( + dim=1, index=target_right_x_inds) + target_right_y = pts_target_all_dy.gather( + dim=1, index=target_right_x_inds) + + target_left_x = pts_target_all_dx.gather( + dim=1, index=target_left_x_inds) + target_left_y = pts_target_all_dy.gather( + dim=1, index=target_left_x_inds) + + target_up_x = pts_target_all_dx.gather(dim=1, index=target_up_y_inds) + target_up_y = pts_target_all_dy.gather(dim=1, index=target_up_y_inds) + + target_bottom_x = 
pts_target_all_dx.gather( + dim=1, index=target_bottom_y_inds) + target_bottom_y = pts_target_all_dy.gather( + dim=1, index=target_bottom_y_inds) + + target_corners = jt.concat([ + target_left_x, target_left_y, target_up_x, target_up_y, + target_right_x, target_right_y, target_bottom_x, target_bottom_y + ], + dim=-1) + + pts_pred_dx_mean = pts_pred_all_dx.mean( + dim=1).reshape(-1, 1) + pts_pred_dy_mean = pts_pred_all_dy.mean( + dim=1).reshape(-1, 1) + pts_pred_mean = jt.concat([pts_pred_dx_mean, pts_pred_dy_mean], dim=-1) + + pts_target_dx_mean = pts_target_all_dx.mean( + dim=1).reshape(-1, 1) + pts_target_dy_mean = pts_target_all_dy.mean( + dim=1).reshape(-1, 1) + pts_target_mean = jt.concat([pts_target_dx_mean, pts_target_dy_mean], + dim=-1) + + beta = 1.0 + + diff_mean = jt.abs(pts_pred_mean - pts_target_mean) + diff_mean_loss = jt.where(diff_mean < beta, + 0.5 * diff_mean * diff_mean / beta, + diff_mean - 0.5 * beta) + diff_mean_loss = diff_mean_loss.sum() / len(diff_mean_loss) + + diff_corners = jt.abs(pred_corners - target_corners) + diff_corners_loss = jt.where( + diff_corners < beta, 0.5 * diff_corners * diff_corners / beta, + diff_corners - 0.5 * beta) + diff_corners_loss = diff_corners_loss.sum() / len(diff_corners_loss) + + target_aspect = AspectRatio(target) + smooth_loss_weight = jt.exp((-1 / 4) * target_aspect) + loss = \ + smooth_loss_weight * (diff_mean_loss.reshape(-1, 1) + + diff_corners_loss.reshape(-1, 1)) + \ + 1 - (1 - 2 * smooth_loss_weight) * convex_gious + + if weight is not None: + loss = loss * weight + grad = grad * weight.reshape(-1, 1) + if reduction == 'sum': + loss = loss.sum() + elif reduction == 'mean': + loss = loss.mean() + + unvaild_inds = jt.nonzero((grad > 1).sum(1))[:, 0] + grad[unvaild_inds] = 1e-6 + + reduce_grad = -grad / grad.size(0) * loss_weight + self.convex_points_grad = reduce_grad + return loss + + def grad(self, input=None): + """Backward function.""" + convex_points_grad = self.convex_points_grad + return 
convex_points_grad, None, None, None, None, None + + +bc_convex_giou_loss = BCConvexGIoULossFuction.apply + + +@LOSSES.register_module() +class BCConvexGIoULoss(nn.Module): + """BCConvex GIoU loss. + + Computing the BCConvex GIoU loss between a set of predicted convexes and + target convexes. + + Args: + reduction (str, optional): The reduction method of the loss. Defaults + to 'mean'. + loss_weight (float, optional): The weight of loss. Defaults to 1.0. + + Return: + jt.Tensor: Loss tensor. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super(BCConvexGIoULoss, self).__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def execute(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function. + + Args: + pred (jt.Tensor): Predicted convexes. + target (jt.Tensor): Corresponding gt convexes. + weight (jt.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + if weight is not None and not jt.any(weight > 0): + return (pred * weight.unsqueeze(-1)).sum() + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss = self.loss_weight * bc_convex_giou_loss( + pred, target, weight, reduction, avg_factor, self.loss_weight) + return loss + + +def AspectRatio(gt_rbboxes): + """Compute the aspect ratio of all gts. + + Args: + gt_rbboxes (jt.Tensor): Groundtruth polygons, shape (k, 8). + + Returns: + ratios (jt.Tensor): The aspect ratio of gt_rbboxes, shape (k, 1). 
+ """ + pt1, pt2, pt3, pt4 = gt_rbboxes[..., :8].chunk(4, 1) + edge1 = jt.sqrt( + jt.pow(pt1[..., 0] - pt2[..., 0], 2) + + jt.pow(pt1[..., 1] - pt2[..., 1], 2)) + edge2 = jt.sqrt( + jt.pow(pt2[..., 0] - pt3[..., 0], 2) + + jt.pow(pt2[..., 1] - pt3[..., 1], 2)) + + edges = jt.stack([edge1, edge2], dim=1) + + width= jt.max(edges, 1) + height = jt.min(edges, 1) + ratios = (width / height) + return ratios diff --git a/python/jdet/models/losses/spatial_border_loss.py b/python/jdet/models/losses/spatial_border_loss.py new file mode 100644 index 0000000..863120b --- /dev/null +++ b/python/jdet/models/losses/spatial_border_loss.py @@ -0,0 +1,99 @@ +import jittor as jt +from jittor import nn +from jdet.utils.registry import LOSSES + +from jdet.models.boxes.box_ops import points_in_rotated_boxes +from jdet.ops.bbox_transforms import poly2obb +import numpy as np + +@LOSSES.register_module() +class SpatialBorderLoss(nn.Module): + """Spatial Border loss for learning points in Oriented RepPoints. + Args: + pts (jt.Tensor): point sets with shape (N, 9*2). + Default points number in each point set is 9. + gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8) + Returns: + loss (jt.Tensor) + """ + + def __init__(self, loss_weight=1.0): + super(SpatialBorderLoss, self).__init__() + self.loss_weight = loss_weight + + def execute(self, pts, gt_bboxes, weight, *args, **kwargs): + loss = self.loss_weight * weighted_spatial_border_loss( + pts, gt_bboxes, weight, *args, **kwargs) + return loss + +def to_le135(boxes): + # swap edge and angle if h >= w + x, y, w, h, t = boxes.unbind(dim=-1) + start_angle = -45 / 180 * np.pi + w_ = jt.where(w > h, w, h) + h_ = jt.where(w > h, h, w) + t = jt.where(w > h, t, t + np.pi / 2) + t = ((t - start_angle) % np.pi) + start_angle + return jt.stack([x, y, w_, h_, t], dim=-1) + +def spatial_border_loss(pts, gt_bboxes): + """The loss is used to penalize the learning points out of the assigned + ground truth boxes (polygon by default). 
+ Args: + pts (jt.Tensor): point sets with shape (N, 9*2). + gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8) + Returns: + loss (jt.Tensor) + """ + + num_gts, num_pointsets = gt_bboxes.size(0), pts.size(0) + num_point = int(pts.size(1) / 2.0) + loss = jt.zeros([0]) + gt_bboxes_ = to_le135(poly2obb(gt_bboxes)) + + if num_gts > 0: + inside_flag_list = [] + for i in range(num_point): + pt = pts[:, (2 * i):(2 * i + 2)].reshape(num_pointsets, + 2).contiguous() + # inside_pt_flag = points_in_polygons(pt, gt_bboxes) + inside_pt_flag = points_in_rotated_boxes(pt, gt_bboxes_) + inside_pt_flag = jt.diag(inside_pt_flag) + inside_flag_list.append(inside_pt_flag) + + inside_flag = jt.stack(inside_flag_list, dim=1) + pts = pts.reshape(-1, num_point, 2) + out_border_pts = pts[jt.where(inside_flag == 0)] + + if out_border_pts.size(0) > 0: + corr_gt_boxes = gt_bboxes[jt.where(inside_flag == 0)[0]] + corr_gt_boxes_center_x = (corr_gt_boxes[:, 0] + + corr_gt_boxes[:, 4]) / 2.0 + corr_gt_boxes_center_y = (corr_gt_boxes[:, 1] + + corr_gt_boxes[:, 5]) / 2.0 + corr_gt_boxes_center = jt.stack( + [corr_gt_boxes_center_x, corr_gt_boxes_center_y], dim=1) + distance_out_pts = 0.2 * (( + (out_border_pts - corr_gt_boxes_center)**2).sum(dim=1).sqrt()) + loss = distance_out_pts.sum() / out_border_pts.size(0) + + return loss + + +def weighted_spatial_border_loss(pts, gt_bboxes, weight, avg_factor=None): + """Weghted spatial border loss. + Args: + pts (jt.Tensor): point sets with shape (N, 9*2). 
+ gt_bboxes (jt.Tensor): gt_bboxes with polygon form with shape(N, 8) + weight (jt.Tensor): weights for point sets with shape (N) + Returns: + loss (jt.Tensor) + """ + + weight = weight.unsqueeze(dim=1).repeat(1, 4) + assert len(weight.shape) == 2 + if avg_factor is None: + avg_factor = jt.sum(weight > 0).float().item() / 4 + 1e-6 + loss = spatial_border_loss(pts, gt_bboxes) + + return jt.sum(loss)[None] / avg_factor diff --git a/python/jdet/models/roi_heads/__init__.py b/python/jdet/models/roi_heads/__init__.py index c82b747..19e79df 100644 --- a/python/jdet/models/roi_heads/__init__.py +++ b/python/jdet/models/roi_heads/__init__.py @@ -19,4 +19,6 @@ from . import rsdet_head from . import rotated_atss_head from . import rotated_reppoints_head +from . import oriented_reppoints_head +from . import sam_reppoints_head __all__ = [] diff --git a/python/jdet/models/roi_heads/oriented_reppoints_head.py b/python/jdet/models/roi_heads/oriented_reppoints_head.py new file mode 100644 index 0000000..628ba8f --- /dev/null +++ b/python/jdet/models/roi_heads/oriented_reppoints_head.py @@ -0,0 +1,1527 @@ +from jdet.utils.registry import HEADS, LOSSES, BOXES, build_from_cfg +from jdet.models.utils.modules import ConvModule +from jdet.ops.dcn_v1 import DeformConv +from jdet.ops.bbox_transforms import obb2poly, poly2obb +from jdet.utils.general import multi_apply, unmap +from jdet.ops.nms_rotated import multiclass_nms_rotated +from jdet.models.boxes.anchor_target import images_to_levels +from jdet.ops.reppoints_convex_iou import reppoints_convex_iou +from jdet.ops.reppoints_min_area_bbox import reppoints_min_area_bbox +from jdet.ops.chamfer_distance import chamfer_distance +from jdet.models.boxes.box_ops import rotated_box_to_poly + +import math +import jittor as jt +from jittor import nn +import numpy as np +from jittor.nn import _pair, grid_sample + +import numpy as np +def deleteme(a, b, size = 10): + if a is None and b is None: + return + if isinstance(a, dict) and isinstance(b, 
dict): + print('-' * size) + for a1, b1 in zip(a.values(), b.values()): + deleteme(a1, b1, size + 10) + print('-' * size) + elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + print('-' * size) + for a1, b1 in zip(a, b): + deleteme(a1, b1, size + 10) + print('-' * size) + elif isinstance(a, jt.Var) and isinstance(b, np.ndarray): + print((a - b).abs().max().item()) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + print(np.max(np.abs(a - b))) + elif isinstance(a, (int, float)) and isinstance(b, (int, float)): + print("number diff:", a - b) + else: + print(type(a)) + print(type(b)) + raise NotImplementedError +def transpose_to(a, b): + if a is None: + return None + if isinstance(a, list) and isinstance(b, list): + rlist = [] + for a1, b1 in zip(a, b): + rlist.append(transpose_to(a1, b1)) + return rlist + elif isinstance(a, dict) and isinstance(b, dict): + rdict = {} + for k in b.keys(): + rdict[k] = transpose_to(a[k], b[k]) + return rdict + elif isinstance(a, np.ndarray) and isinstance(b, jt.Var): + return jt.array(a) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return a + elif isinstance(a, tuple) and isinstance(b, tuple): + rlist = [transpose_to(a1, b1) for a1, b1 in zip(a, b)] + return tuple(rlist) + elif isinstance(a, (int, float, str)) and isinstance(b, (int, float, str)): + assert(type(a) == type(b)) + return a + else: + print(type(a)) + print(type(b)) + raise NotImplementedError + +def fake_argsort2(x, dim=0, descending=False): + x_ = x.data + if (descending): + x__ = -x_ + else: + x__ = x_ + index_ = np.argsort(x__, axis=dim, kind="stable") + y_ = x_[index_] + index = jt.array(index_) + y = jt.array(y_) + return index, y + +def ChamferDistance2D(point_set_1, + point_set_2, + distance_weight=0.05, + eps=1e-12): + """Compute the Chamfer distance between two point sets. 
+ + Args: + point_set_1 (jt.tensor): point set 1 with shape (N_pointsets, + N_points, 2) + point_set_2 (jt.tensor): point set 2 with shape (N_pointsets, + N_points, 2) + + Returns: + dist (jt.tensor): chamfer distance between two point sets + with shape (N_pointsets,) + """ + assert point_set_1.shape == point_set_2.shape + assert point_set_1.shape[-1] == point_set_2.shape[-1] + # assert point_set_1.dim() <= 3 + dist1, dist2, _, _ = chamfer_distance(point_set_1, point_set_2) + dist1 = jt.sqrt(jt.clamp(dist1, eps)) + dist2 = jt.sqrt(jt.clamp(dist2, eps)) + dist = distance_weight * (dist1.mean(-1) + dist2.mean(-1)) / 2.0 + + return dist + + +@HEADS.register_module() +class OrientedRepPointsHead(nn.Module): + """Rotated RepPoints head. + + Args: + num_classes (int): Number of classes. + in_channels (int): Number of input channels. + feat_channels (int): Number of feature channels. + point_feat_channels (int, optional): Number of channels of points + features. + stacked_convs (int, optional): Number of stacked convolutions. + num_points (int, optional): Number of points in points set. + gradient_mul (float, optional): The multiplier to gradients from + points refinement and recognition. + point_strides (Iterable, optional): points strides. + point_base_scale (int, optional): Bbox scale for assigning labels. + conv_bias (str, optional): The bias of convolution. + loss_cls (dict, optional): Config of classification loss. + loss_bbox_init (dict, optional): Config of initial points loss. + loss_bbox_refine (dict, optional): Config of points loss in refinement. + conv_cfg (dict, optional): The config of convolution. + norm_cfg (dict, optional): The config of normlization. + train_cfg (dict, optional): The config of train. + test_cfg (dict, optional): The config of test. + center_init (bool, optional): Whether to use center point assignment. + top_ratio (float, optional): Ratio of top high-quality point sets. + Defaults to 0.4. 
+ init_qua_weight (float, optional): Quality weight of initial + stage. + ori_qua_weight (float, optional): Orientation quality weight. + poc_qua_weight (float, optional): Point-wise correlation + quality weight. + """ + + def __init__(self, + num_classes, + in_channels, + feat_channels, + point_feat_channels=256, + stacked_convs=3, + num_points=9, + gradient_mul=0.1, + point_strides=[8, 16, 32, 64, 128], + point_base_scale=4, + conv_bias='auto', + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_init=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5), + loss_bbox_refine=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), + loss_spatial_init=dict( + type='SpatialBorderLoss', loss_weight=0.05), + loss_spatial_refine=dict( + type='SpatialBorderLoss', loss_weight=0.1), + conv_cfg=None, + norm_cfg=None, + train_cfg=None, + test_cfg=None, + center_init=True, + top_ratio=0.4, + init_qua_weight=0.2, + ori_qua_weight=0.3, + poc_qua_weight=0.1, + **kwargs): + self.num_points = num_points + self.point_feat_channels = point_feat_channels + self.center_init = center_init + + # we use deform conv to extract points features + self.dcn_kernel = int(np.sqrt(num_points)) + self.dcn_pad = int((self.dcn_kernel - 1) / 2) + assert self.dcn_kernel * self.dcn_kernel == num_points, \ + 'The points number should be a square number.' + assert self.dcn_kernel % 2 == 1, \ + 'The points number should be an odd square number.' 
+ dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64) + dcn_base_y = np.repeat(dcn_base, self.dcn_kernel) + dcn_base_x = np.tile(dcn_base, self.dcn_kernel) + dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1)) + self.dcn_base_offset = jt.array(dcn_base_offset).view(1, -1, 1, 1) + self.num_classes = num_classes + self.in_channels = in_channels + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + assert conv_bias == 'auto' or isinstance(conv_bias, bool) + self.conv_bias = conv_bias + self.loss_cls = build_from_cfg(loss_cls, LOSSES) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.gradient_mul = gradient_mul + self.point_base_scale = point_base_scale + self.point_strides = point_strides + self.prior_generator = MlvlPointGenerator(self.point_strides, offset=0.) + self.num_base_priors = self.prior_generator.num_base_priors[0] + self.sampling = loss_cls['type'] not in ['FocalLoss'] + if self.train_cfg: + self.init_assigner = build_from_cfg(self.train_cfg.init.assigner, BOXES) + self.refine_assigner = build_from_cfg(self.train_cfg.refine.assigner, BOXES) + # use PseudoSampler when sampling is False + if self.sampling and hasattr(self.train_cfg, 'sampler'): + sampler_cfg = self.train_cfg.sampler + else: + sampler_cfg = dict(type='PseudoSampler') + self.sampler = build_from_cfg(sampler_cfg, BOXES) + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = self.num_classes + else: + self.cls_out_channels = self.num_classes + 1 + self.loss_bbox_init = build_from_cfg(loss_bbox_init, LOSSES) + self.loss_bbox_refine = build_from_cfg(loss_bbox_refine, LOSSES) + self.loss_spatial_init = build_from_cfg(loss_spatial_init, LOSSES) + self.loss_spatial_refine = build_from_cfg(loss_spatial_refine, LOSSES) + self.init_qua_weight = init_qua_weight + self.ori_qua_weight = ori_qua_weight + 
self.poc_qua_weight = poc_qua_weight + self.top_ratio = top_ratio + self._init_layers() + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU() + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.conv_bias)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + bias=self.conv_bias)) + pts_out_dim = 2 * self.num_points + self.reppoints_cls_conv = DeformConv(self.feat_channels, + self.point_feat_channels, + self.dcn_kernel, 1, + self.dcn_pad) + self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels, + self.cls_out_channels, 1, 1, 0) + self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, + self.point_feat_channels, 3, + 1, 1) + self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, + pts_out_dim, 1, 1, 0) + self.reppoints_pts_refine_conv = DeformConv(self.feat_channels, + self.point_feat_channels, + self.dcn_kernel, 1, + self.dcn_pad) + self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, + pts_out_dim, 1, 1, 0) + + def points2rotrect(self, pts, y_first=True): + """Convert points to oriented bboxes.""" + if y_first: + pts = pts.reshape(-1, self.num_points, 2) + pts_dy = pts[:, :, 0::2] + pts_dx = pts[:, :, 1::2] + pts = jt.concat([pts_dx, pts_dy], + dim=2).reshape(-1, 2 * self.num_points) + return reppoints_min_area_bbox(pts) + + + def forward_single(self, x): + """Forward feature map of a single FPN level.""" + dcn_base_offset = self.dcn_base_offset.type_as(x) + points_init = 0 + cls_feat = x + pts_feat = x + base_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + 
pts_feat = reg_conv(pts_feat) + # initialize reppoints + pts_out_init = self.reppoints_pts_init_out( + self.relu(self.reppoints_pts_init_conv(pts_feat))) + pts_out_init = pts_out_init + points_init + # refine and classify reppoints + pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init + dcn_offset = pts_out_init_grad_mul - dcn_base_offset + cls_out = self.reppoints_cls_out( + self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) + pts_out_refine = self.reppoints_pts_refine_out( + self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) + pts_out_refine = pts_out_refine + pts_out_init.detach() + + return cls_out, pts_out_init, pts_out_refine, base_feat + + def get_points(self, featmap_sizes, img_metas): + """Get points according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + + Returns: + tuple: points of each image, valid flags of each image + """ + num_imgs = len(img_metas) + + multi_level_points = self.prior_generator.grid_priors( + featmap_sizes, with_stride=True) + points_list = [[point.clone() for point in multi_level_points] + for _ in range(num_imgs)] + + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = self.prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape']) + valid_flag_list.append(multi_level_flags) + + return points_list, valid_flag_list + + def offset_to_pts(self, center_list, pred_list): + """Change from point offset to point coordinate.""" + pts_list = [] + for i_lvl, _ in enumerate(self.point_strides): + pts_lvl = [] + for i_img, _ in enumerate(center_list): + pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points) + pts_shift = pred_list[i_lvl][i_img] + yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points) + y_pts_shift = yx_pts_shift[..., 0::2] + x_pts_shift = yx_pts_shift[..., 1::2] + xy_pts_shift = 
jt.stack([x_pts_shift, y_pts_shift], -1) + xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) + pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center + pts_lvl.append(pts) + pts_lvl = jt.stack(pts_lvl, 0) + pts_list.append(pts_lvl) + return pts_list + + def get_adaptive_points_feature(self, features, pt_locations, stride): + """Get the points features from the locations of predicted points. + + Args: + features (jt.tensor): base feature with shape (B,C,W,H) + pt_locations (jt.tensor): locations of points in each point set + with shape (B, N_points_set(number of point set), + N_points(number of points in each point set) *2) + Returns: + tensor: sampling features with (B, C, N_points_set, N_points) + """ + + h = features.shape[2] * stride + w = features.shape[3] * stride + + pt_locations = pt_locations.view(pt_locations.shape[0], + pt_locations.shape[1], -1, 2).clone() + pt_locations[..., 0] = pt_locations[..., 0] / (w / 2.) - 1 + pt_locations[..., 1] = pt_locations[..., 1] / (h / 2.) - 1 + + batch_size = features.size(0) + sampled_features = jt.zeros([ + pt_locations.shape[0], + features.size(1), + pt_locations.size(1), + pt_locations.size(2) + ]) + + for i in range(batch_size): + feature = grid_sample(features[i:i + 1], + pt_locations[i:i + 1])[0] + sampled_features[i] = feature + + return sampled_features, + + def feature_cosine_similarity(self, points_features): + """Compute the points features similarity for points-wise correlation. 
+ + Args: + points_features (jt.tensor): sampling point feature with + shape (N_pointsets, N_points, C) + Returns: + max_correlation: max feature similarity in each point set with + shape (N_points_set, N_points, C) + """ + + mean_points_feats = jt.mean(points_features, dim=1, keepdims=True) + norm_pts_feats = jt.norm( + points_features, p=2, dim=2).unsqueeze(dim=2).clamp(min_v=1e-2) + norm_mean_pts_feats = jt.norm( + mean_points_feats, p=2, dim=2).unsqueeze(dim=2).clamp(min_v=1e-2) + + unity_points_features = points_features / norm_pts_feats + unity_mean_points_feats = mean_points_feats / norm_mean_pts_feats + + # cos_similarity = nn.CosineSimilarity(dim=2, eps=1e-6) + unity_points_features = 1. * unity_points_features / (jt.norm(unity_points_features, 2, 2, keepdims=True).expand_as(unity_mean_points_feats) + 1e-6) + unity_mean_points_feats = 1. * unity_mean_points_feats / (jt.norm(unity_mean_points_feats, 2, 2, keepdims=True).expand_as(unity_mean_points_feats) + 1e-6) + feats_similarity = 1.0 - (unity_points_features * unity_mean_points_feats).sum(dim=-1) + # feats_similarity = 1.0 - cos_similarity(unity_points_features, + # unity_mean_points_feats) + + max_correlation = jt.max(feats_similarity, dim=1) + + return max_correlation + + def sampling_points(self, polygons, points_num): + """Sample edge points for polygon. + + Args: + polygons (jt.tensor): polygons with shape (N, 8) + points_num (int): number of sampling points for each polygon edge. + 10 by default. 
+ + Returns: + sampling_points (jt.tensor): sampling points with shape (N, + points_num*4, 2) + """ + polygons_xs, polygons_ys = polygons[:, 0::2], polygons[:, 1::2] + ratio = jt.linspace(0, 1, points_num).repeat( + polygons.shape[0], 1) + + edge_pts_x = [] + edge_pts_y = [] + for i in range(4): + if i < 3: + points_x = ratio * polygons_xs[:, i + 1:i + 2] + ( + 1 - ratio) * polygons_xs[:, i:i + 1] + points_y = ratio * polygons_ys[:, i + 1:i + 2] + ( + 1 - ratio) * polygons_ys[:, i:i + 1] + else: + points_x = ratio * polygons_xs[:, 0].unsqueeze(1) + ( + 1 - ratio) * polygons_xs[:, i].unsqueeze(1) + points_y = ratio * polygons_ys[:, 0].unsqueeze(1) + ( + 1 - ratio) * polygons_ys[:, i].unsqueeze(1) + + edge_pts_x.append(points_x) + edge_pts_y.append(points_y) + + sampling_points_x = jt.concat(edge_pts_x, dim=1).unsqueeze(dim=2) + sampling_points_y = jt.concat(edge_pts_y, dim=1).unsqueeze(dim=2) + sampling_points = jt.concat([sampling_points_x, sampling_points_y], + dim=2) + + return sampling_points + + def pointsets_quality_assessment(self, pts_features, cls_score, + pts_pred_init, pts_pred_refine, label, + bbox_gt, label_weight, bbox_weight, + pos_inds): + """Assess the quality of each point set from the classification, + localization, orientation, and point-wise correlation based on + the assigned point sets samples. 
+ Args: + pts_features (jt.tensor): points features with shape (N, 9, C) + cls_score (jt.tensor): classification scores with + shape (N, class_num) + pts_pred_init (jt.tensor): initial point sets prediction with + shape (N, 9*2) + pts_pred_refine (jt.tensor): refined point sets prediction with + shape (N, 9*2) + label (jt.tensor): gt label with shape (N) + bbox_gt(jt.tensor): gt bbox of polygon with shape (N, 8) + label_weight (jt.tensor): label weight with shape (N) + bbox_weight (jt.tensor): box weight with shape (N) + pos_inds (jt.tensor): the inds of positive point set samples + + Returns: + qua (jt.tensor) : weighted quality values for positive + point set samples. + """ + pos_scores = cls_score[pos_inds] + pos_pts_pred_init = pts_pred_init[pos_inds] + pos_pts_pred_refine = pts_pred_refine[pos_inds] + pos_pts_refine_features = pts_features[pos_inds] + pos_bbox_gt = bbox_gt[pos_inds] + pos_label = label[pos_inds] + pos_label_weight = label_weight[pos_inds] + pos_bbox_weight = bbox_weight[pos_inds] + + # quality of point-wise correlation + qua_poc = self.poc_qua_weight * self.feature_cosine_similarity( + pos_pts_refine_features) + + qua_cls = self.loss_cls( + pos_scores, + pos_label, + pos_label_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + + polygons_pred_init = reppoints_min_area_bbox(pos_pts_pred_init) + polygons_pred_refine = reppoints_min_area_bbox(pos_pts_pred_refine) + sampling_pts_pred_init = self.sampling_points( + polygons_pred_init, 10) + sampling_pts_pred_refine = self.sampling_points( + polygons_pred_refine, 10) + sampling_pts_gt = self.sampling_points(pos_bbox_gt, 10) + + # quality of orientation + qua_ori_init = self.ori_qua_weight * ChamferDistance2D( + sampling_pts_gt, sampling_pts_pred_init) + qua_ori_refine = self.ori_qua_weight * ChamferDistance2D( + sampling_pts_gt, sampling_pts_pred_refine) + + # quality of localization + qua_loc_init = self.loss_bbox_refine( + pos_pts_pred_init, + pos_bbox_gt, + 
pos_bbox_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + qua_loc_refine = self.loss_bbox_refine( + pos_pts_pred_refine, + pos_bbox_gt, + pos_bbox_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + + # quality of classification + qua_cls = qua_cls.sum(-1) + + # weighted inti-stage and refine-stage + qua = qua_cls + self.init_qua_weight * ( + qua_loc_init + qua_ori_init) + (1.0 - self.init_qua_weight) * ( + qua_loc_refine + qua_ori_refine) + qua_poc + + return qua, + + def dynamic_pointset_samples_selection(self, + quality, + label, + label_weight, + bbox_weight, + pos_inds, + pos_gt_inds, + num_proposals_each_level=None, + num_level=None): + """The dynamic top k selection of point set samples based on the + quality assessment values. + + Args: + quality (jt.tensor): the quality values of positive + point set samples + label (jt.tensor): gt label with shape (N) + bbox_gt(jt.tensor): gt bbox of polygon with shape (N, 8) + label_weight (jt.tensor): label weight with shape (N) + bbox_weight (jt.tensor): box weight with shape (N) + pos_inds (jt.tensor): the inds of positive point set samples + num_proposals_each_level (list[int]): proposals number of + each level + num_level (int): the level number + Returns: + label: gt label with shape (N) + label_weight: label weight with shape (N) + bbox_weight: box weight with shape (N) + num_pos (int): the number of selected positive point samples + with high-qualty + pos_normalize_term (jt.tensor): the corresponding positive + normalize term + """ + + if len(pos_inds) == 0: + return label, label_weight, bbox_weight, 0, jt.array( + []).type_as(bbox_weight) + + num_gt = pos_gt_inds.max() + num_proposals_each_level_ = num_proposals_each_level.copy() + num_proposals_each_level_.insert(0, 0) + inds_level_interval = np.cumsum(num_proposals_each_level_) + pos_level_mask = [] + for i in range(num_level): + mask = (pos_inds >= inds_level_interval[i]) & ( + pos_inds < 
inds_level_interval[i + 1]) + pos_level_mask.append(mask) + + pos_inds_after_select = [] + ignore_inds_after_select = [] + + if int(num_gt) == 0: + return label, label_weight, bbox_weight, 0, jt.array( + []).type_as(bbox_weight) + + for gt_ind in range(int(num_gt)): + pos_inds_select = [] + pos_loss_select = [] + gt_mask = pos_gt_inds == (gt_ind + 1) + for level in range(num_level): + level_mask = pos_level_mask[level] + level_gt_mask = level_mask & gt_mask + mask_quality = quality[level_gt_mask] + if level_gt_mask.sum() <= 6: + topk_inds, value = jt.argsort(mask_quality,dim=mask_quality.ndim-1,descending=False) + else: + value, topk_inds = mask_quality.topk( + 6, largest=False) + pos_inds_select.append(pos_inds[level_gt_mask][topk_inds]) + pos_loss_select.append(value) + pos_inds_select = jt.concat(pos_inds_select) + pos_loss_select = jt.concat(pos_loss_select) + + if len(pos_inds_select) < 2: + pos_inds_after_select.append(pos_inds_select) + ignore_inds_after_select.append(jt.empty([])) + + else: + # pos_loss_select, sort_inds = pos_loss_select.sort() # small to large + # pos_loss_select, sort_inds = fake_argsort2(pos_loss_select) + sort_inds, pos_loss_select = fake_argsort2(pos_loss_select) + pos_inds_select = pos_inds_select[sort_inds] + # dynamic top k + topk = math.ceil(pos_loss_select.shape[0] * self.top_ratio) + pos_inds_select_topk = pos_inds_select[:topk] + pos_inds_after_select.append(pos_inds_select_topk) + ignore_inds_after_select.append( + jt.empty([])) + + pos_inds_after_select = jt.concat(pos_inds_after_select) + ignore_inds_after_select = jt.concat(ignore_inds_after_select) + + reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_select).all(1) + reassign_ids = pos_inds[reassign_mask] + label[reassign_ids] = 0 + label_weight[ignore_inds_after_select] = 0 + bbox_weight[reassign_ids] = 0 + num_pos = len(pos_inds_after_select) + + pos_level_mask_after_select = [] + for i in range(num_level): + mask = (pos_inds_after_select >= 
inds_level_interval[i]) & ( + pos_inds_after_select < inds_level_interval[i + 1]) + pos_level_mask_after_select.append(mask) + pos_level_mask_after_select = jt.stack(pos_level_mask_after_select, + 0).type_as(label) + pos_normalize_term = pos_level_mask_after_select * ( + self.point_base_scale * + jt.array(self.point_strides).type_as(label)).reshape(-1, 1) + pos_normalize_term = pos_normalize_term[ + pos_normalize_term > 0].type_as(bbox_weight) + assert len(pos_normalize_term) == len(pos_inds_after_select) + + return label, label_weight, bbox_weight, num_pos, pos_normalize_term + + def _point_target_single(self, + flat_proposals, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + overlaps, + stage='init', + unmap_outputs=True): + """Single point target function.""" + inside_flags = valid_flags + if not inside_flags.any(): + return (None, ) * 8 + # assign gt and sample proposals + proposals = flat_proposals[inside_flags, :] + + if stage == 'init': + assigner = self.init_assigner + pos_weight = self.train_cfg.init.pos_weight + else: + assigner = self.refine_assigner + pos_weight = self.train_cfg.refine.pos_weight + + # convert gt from obb to poly + gt_bboxes = obb2poly(gt_bboxes) + + assign_result = assigner.assign(proposals, gt_bboxes, + gt_bboxes_ignore, + None if self.sampling else gt_labels) + sampling_result = self.sampler.sample(assign_result, proposals, + gt_bboxes) + + + num_valid_proposals = proposals.shape[0] + bbox_gt = jt.zeros([num_valid_proposals, 8], dtype=proposals.dtype) + pos_proposals = jt.zeros_like(proposals) + proposals_weights = jt.zeros(num_valid_proposals, dtype=proposals.dtype) + labels = jt.full((num_valid_proposals, ), + 0, + dtype=jt.int32) + label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_gt_bboxes = sampling_result.pos_gt_bboxes + bbox_gt[pos_inds, :] = pos_gt_bboxes + pos_proposals[pos_inds, :] = 
proposals[pos_inds, :] + proposals_weights[pos_inds] = 1.0 + if gt_labels is None: + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of proposals + if unmap_outputs: + num_total_proposals = flat_proposals.size(0) + labels = unmap(labels, num_total_proposals, inside_flags) + label_weights = unmap(label_weights, num_total_proposals, + inside_flags) + bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) + pos_proposals = unmap(pos_proposals, num_total_proposals, + inside_flags) + proposals_weights = unmap(proposals_weights, num_total_proposals, + inside_flags) + + return (labels, label_weights, bbox_gt, pos_proposals, + proposals_weights, pos_inds, neg_inds, sampling_result) + + def init_loss_single(self, pts_pred_init, bbox_gt_init, bbox_weights_init, + stride): + """Single initial stage loss function.""" + normalize_term = self.point_base_scale * stride + + bbox_gt_init = bbox_gt_init.reshape(-1, 8) + bbox_weights_init = bbox_weights_init.reshape(-1) + pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) + pos_ind_init = (bbox_weights_init > 0).nonzero().reshape(-1) + + pts_pred_init_norm = pts_pred_init[pos_ind_init] + bbox_gt_init_norm = bbox_gt_init[pos_ind_init] + bbox_weights_pos_init = bbox_weights_init[pos_ind_init] + + loss_pts_init = self.loss_bbox_init( + pts_pred_init_norm / normalize_term, + bbox_gt_init_norm / normalize_term, bbox_weights_pos_init) + + loss_border_init = self.loss_spatial_init( + pts_pred_init_norm.reshape(-1, 2 * self.num_points) / + normalize_term, + bbox_gt_init_norm / normalize_term, + bbox_weights_pos_init, + avg_factor=None) + + return loss_pts_init, loss_border_init + + def get_targets(self, + proposals_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + 
gt_bboxes_ignore_list=None, + gt_labels_list=None, + stage='init', + label_channels=1, + unmap_outputs=True): + """Compute corresponding GT box and classification targets for + proposals. + + Args: + proposals_list (list[list]): Multi level points/bboxes of each + image. + valid_flag_list (list[list]): Multi level valid flags of each + image. + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be + ignored. + gt_bboxes_list (list[Tensor]): Ground truth labels of each box. + stage (str): `init` or `refine`. Generate target for init stage or + refine stage + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple (list[Tensor]): + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each \ + level. + - bbox_gt_list (list[Tensor]): Ground truth bbox of each level. + - proposal_list (list[Tensor]): Proposals(points/bboxes) of \ + each level. + - proposal_weights_list (list[Tensor]): Proposal weights of \ + each level. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. 
+ """ + assert stage in ['init', 'refine'] + num_imgs = len(img_metas) + assert len(proposals_list) == len(valid_flag_list) == num_imgs + + # points number of multi levels + num_level_proposals = [points.size(0) for points in proposals_list[0]] + + # concat all level points and flags to a single tensor + for i in range(num_imgs): + assert len(proposals_list[i]) == len(valid_flag_list[i]) + proposals_list[i] = jt.concat(proposals_list[i]) + valid_flag_list[i] = jt.concat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + all_overlaps_rotate_list = [None] * len(proposals_list) + (all_labels, all_label_weights, all_bbox_gt, all_proposals, + all_proposal_weights, pos_inds_list, neg_inds_list, + sampling_result) = multi_apply( + self._point_target_single, + proposals_list, + valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + all_overlaps_rotate_list, + stage=stage, + unmap_outputs=unmap_outputs) + + if stage == 'init': + # no valid points + if any([labels is None for labels in all_labels]): + return None + # sampled points of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + labels_list = images_to_levels(all_labels, num_level_proposals) + label_weights_list = images_to_levels(all_label_weights, + num_level_proposals) + bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) + proposals_list = images_to_levels(all_proposals, num_level_proposals) + proposal_weights_list = images_to_levels(all_proposal_weights, + num_level_proposals) + + return (labels_list, label_weights_list, bbox_gt_list, proposals_list, + proposal_weights_list, num_total_pos, num_total_neg, None) + else: + pos_inds = [] + # pos_gt_index = [] + for i, single_labels in 
enumerate(all_labels): + pos_mask = single_labels > 0 + pos_inds.append(pos_mask.nonzero().view(-1)) + + gt_inds = [item.pos_assigned_gt_inds for item in sampling_result] + + return (all_labels, all_label_weights, all_bbox_gt, all_proposals, + all_proposal_weights, pos_inds, gt_inds) + + def loss(self, + cls_scores, + pts_preds_init, + pts_preds_refine, + base_features, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Loss function of CFA head.""" + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + + # target for initial stage + center_list, valid_flag_list = self.get_points( + featmap_sizes, img_metas) + pts_coordinate_preds_init = self.offset_to_pts(center_list, + pts_preds_init) + + num_proposals_each_level = [(featmap.size(-1) * featmap.size(-2)) + for featmap in cls_scores] + num_level = len(featmap_sizes) + assert num_level == len(pts_coordinate_preds_init) + + if self.train_cfg.init.assigner['type'] == 'ConvexAssigner': + candidate_list = center_list + else: + raise NotImplementedError + cls_reg_targets_init = self.get_targets( + candidate_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + stage='init', + label_channels=label_channels) + (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, + num_total_pos_init, num_total_neg_init, _) = cls_reg_targets_init + + # target for refinement stage + center_list, valid_flag_list = self.get_points( + featmap_sizes, img_metas) + pts_coordinate_preds_refine = self.offset_to_pts( + center_list, pts_preds_refine) + + refine_points_features, = multi_apply(self.get_adaptive_points_feature, + base_features, + pts_coordinate_preds_refine, + self.point_strides) + features_pts_refine = levels_to_images(refine_points_features) + features_pts_refine = [ + 
item.reshape(-1, self.num_points, item.shape[-1]) + for item in features_pts_refine + ] + + points_list = [] + for i_img, center in enumerate(center_list): + points = [] + for i_lvl in range(len(pts_preds_refine)): + points_preds_init_ = pts_preds_init[i_lvl].detach() + points_preds_init_ = points_preds_init_.view( + points_preds_init_.shape[0], -1, + *points_preds_init_.shape[2:]) + points_shift = points_preds_init_.permute( + 0, 2, 3, 1) * self.point_strides[i_lvl] + points_center = center[i_lvl][:, :2].repeat(1, self.num_points) + points.append( + points_center + + points_shift[i_img].reshape(-1, 2 * self.num_points)) + points_list.append(points) + + cls_reg_targets_refine = self.get_targets( + points_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + stage='refine', + label_channels=label_channels) + + (labels_list, label_weights_list, bbox_gt_list_refine, _, + bbox_weights_list_refine, pos_inds_list_refine, + pos_gt_index_list_refine) = cls_reg_targets_refine + + cls_scores = levels_to_images(cls_scores) + cls_scores = [ + item.reshape(-1, self.cls_out_channels) for item in cls_scores + ] + + pts_coordinate_preds_init_img = levels_to_images( + pts_coordinate_preds_init, flatten=True) + pts_coordinate_preds_init_img = [ + item.reshape(-1, 2 * self.num_points) + for item in pts_coordinate_preds_init_img + ] + + pts_coordinate_preds_refine_img = levels_to_images( + pts_coordinate_preds_refine, flatten=True) + pts_coordinate_preds_refine_img = [ + item.reshape(-1, 2 * self.num_points) + for item in pts_coordinate_preds_refine_img + ] + + with jt.no_grad(): + + quality_assess_list, = multi_apply( + self.pointsets_quality_assessment, features_pts_refine, + cls_scores, pts_coordinate_preds_init_img, + pts_coordinate_preds_refine_img, labels_list, + bbox_gt_list_refine, label_weights_list, + bbox_weights_list_refine, pos_inds_list_refine) + + labels_list, label_weights_list, 
bbox_weights_list_refine, \ + num_pos, pos_normalize_term = multi_apply( + self.dynamic_pointset_samples_selection, + quality_assess_list, + labels_list, + label_weights_list, + bbox_weights_list_refine, + pos_inds_list_refine, + pos_gt_index_list_refine, + num_proposals_each_level=num_proposals_each_level, + num_level=num_level + ) + num_pos = sum(num_pos) + + # convert all tensor list to a flatten tensor + cls_scores = jt.concat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) + pts_preds_refine = jt.concat(pts_coordinate_preds_refine_img, 0).view( + -1, pts_coordinate_preds_refine_img[0].size(-1)) + + labels = jt.concat(labels_list, 0).view(-1) + labels_weight = jt.concat(label_weights_list, 0).view(-1) + bbox_gt_refine = jt.concat(bbox_gt_list_refine, + 0).view(-1, bbox_gt_list_refine[0].size(-1)) + bbox_weights_refine = jt.concat(bbox_weights_list_refine, 0).view(-1) + pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1) + pos_inds_flatten = (labels > 0).nonzero().reshape(-1) + + # print('pos_normalize_term: ', pos_normalize_term.shape, pos_inds_flatten.shape) + # assert len(pos_normalize_term) == len(pos_inds_flatten) + + if bool(num_pos) and len(pos_normalize_term) == len(pos_inds_flatten): + losses_cls = self.loss_cls( + cls_scores, labels, labels_weight, avg_factor=num_pos) + pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten] + pos_bbox_gt_refine = bbox_gt_refine[pos_inds_flatten] + + pos_bbox_weights_refine = bbox_weights_refine[pos_inds_flatten] + losses_pts_refine = self.loss_bbox_refine( + pos_pts_pred_refine / pos_normalize_term.reshape(-1, 1), + pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), + pos_bbox_weights_refine) + + loss_border_refine = self.loss_spatial_refine( + pos_pts_pred_refine.reshape(-1, 2 * self.num_points) / + pos_normalize_term.reshape(-1, 1), + pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), + pos_bbox_weights_refine, + avg_factor=None) + + else: + losses_cls = cls_scores.sum() * 0 + 
losses_pts_refine = pts_preds_refine.sum() * 0 + loss_border_refine = pts_preds_refine.sum() * 0 + + losses_pts_init, loss_border_init = multi_apply( + self.init_loss_single, pts_coordinate_preds_init, + bbox_gt_list_init, bbox_weights_list_init, self.point_strides) + + loss_dict_all = { + 'loss_cls': losses_cls, + 'loss_pts_init': losses_pts_init, + 'loss_pts_refine': losses_pts_refine, + 'loss_spatial_init': loss_border_init, + 'loss_spatial_refine': loss_border_refine + } + return loss_dict_all + + def get_bboxes(self, + cls_scores, + pts_preds_init, + pts_preds_refine, + img_metas, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Transform network outputs of a batch into bbox results. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + pts_preds_init (list[Tensor]): Box energies / deltas for all + scale levels, each is a 18D-tensor, has shape + (batch_size, num_points * 2, H, W). + pts_preds_refine (list[Tensor]): Box energies / deltas for all + scale levels, each is a 18D-tensor, has shape + (batch_size, num_points * 2, H, W). + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. + + Returns: + list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 6) tensor, where the first 4 columns + are bounding box positions (cx, cy, w, h, a) and the + 6-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of + the corresponding box. 
+ """ + assert len(cls_scores) == len(pts_preds_refine) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype) + + result_list = [] + + for img_id, _ in enumerate(img_metas): + img_meta = img_metas[img_id] + cls_score_list = select_single_mlvl(cls_scores, img_id) + point_pred_list = select_single_mlvl(pts_preds_refine, img_id) + + results = self._get_bboxes_single(cls_score_list, point_pred_list, + mlvl_priors, img_meta, cfg, + rescale, with_nms, **kwargs) + result_list.append(results) + + return result_list + + def _get_bboxes_single(self, + cls_score_list, + point_pred_list, + mlvl_priors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RepPoints head does not need + this value. + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 2). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. 
Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bboxes, 5], where the first 4 columns are bounding \ + box positions (cx, cy, w, h, a) and the 5-th \ + column are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bboxes]. + """ + + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(point_pred_list) + scale_factor = img_meta['scale_factor'] + + mlvl_bboxes = [] + mlvl_scores = [] + for level_idx, (cls_score, points_pred, points) in enumerate( + zip(cls_score_list, point_pred_list, mlvl_priors)): + assert cls_score.size()[-2:] == points_pred.size()[-2:] + + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1)[:, :-1] + + points_pred = points_pred.permute(1, 2, 0).reshape( + -1, 2 * self.num_points) + nms_pre = cfg.get('nms_pre', -1) + if 0 < nms_pre < scores.shape[0]: + if self.use_sigmoid_cls: + max_scores = scores.max(dim=1) + else: + max_scores = scores[:, 1:].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + points = points[topk_inds, :] + points_pred = points_pred[topk_inds, :] + scores = scores[topk_inds, :] + + poly_pred = self.points2rotrect(points_pred, y_first=True) + bbox_pos_center = points[:, :2].repeat(1, 4) + polys = poly_pred * self.point_strides[level_idx] + bbox_pos_center + bboxes = poly2obb(polys) + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + + mlvl_bboxes = jt.concat(mlvl_bboxes) + + if rescale: + mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor( + scale_factor) + mlvl_scores = jt.concat(mlvl_scores) + if self.use_sigmoid_cls: + padding = jt.zeros((mlvl_scores.shape[0], 1), dtype=mlvl_scores.dtype) + mlvl_scores = jt.concat([padding, mlvl_scores], dim=1) + + if with_nms: + det_bboxes, 
det_labels = multiclass_nms_rotated( + mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, + cfg.max_per_img) + boxes = det_bboxes[:, :5] + scores = det_bboxes[:, 5] + polys = rotated_box_to_poly(boxes) + return polys, scores, det_labels + else: + raise NotImplementedError + + def parse_targets(self, targets): + img_metas = [] + gt_bboxes = [] + gt_bboxes_ignore = [] + gt_labels = [] + + for target in targets: + if self.is_training(): + gt_bboxes.append(target["rboxes"]) + gt_labels.append(target["labels"]) + gt_bboxes_ignore.append(target["rboxes_ignore"]) + img_metas.append(dict( + img_shape=target["img_size"][::-1], + scale_factor=target["scale_factor"], + pad_shape = target["pad_shape"] + )) + if not self.is_training(): + return dict(img_metas = img_metas) + return dict( + gt_bboxes = gt_bboxes, + gt_labels = gt_labels, + img_metas = img_metas, + gt_bboxes_ignore = gt_bboxes_ignore, + ) + + def execute(self, feats, targets): + outs = multi_apply(self.forward_single, feats) + if self.is_training(): + return self.loss(*outs, **self.parse_targets(targets)) + return self.get_bboxes(*outs, **self.parse_targets(targets)) + + +def select_single_mlvl(mlvl_tensors, batch_id, detach=True): + """Extract a multi-scale single image tensor from a multi-scale batch + tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. 
+ """ + assert isinstance(mlvl_tensors, (list, tuple)) + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id].detach() for i in range(num_levels) + ] + else: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id] for i in range(num_levels) + ] + return mlvl_tensor_list + +def levels_to_images(mlvl_tensor, flatten=False): + """Concat multi-level feature maps by image. + + [feature_level0, feature_level1...] -> [feature_image0, feature_image1...] + Convert the shape of each element in mlvl_tensor from (N, C, H, W) to + (N, H*W , C), then split the element to N elements with shape (H*W, C), and + concat elements in same image of all level along first dimension. + + Args: + mlvl_tensor (list[jt.Tensor]): list of Tensor which collect from + corresponding level. Each element is of shape (N, C, H, W) + flatten (bool, optional): if shape of mlvl_tensor is (N, C, H, W) + set False, if shape of mlvl_tensor is (N, H, W, C) set True. + + Returns: + list[jt.Tensor]: A list that contains N tensors and each tensor is + of shape (num_elements, C) + """ + batch_size = mlvl_tensor[0].size(0) + batch_list = [[] for _ in range(batch_size)] + if flatten: + channels = mlvl_tensor[0].size(-1) + else: + channels = mlvl_tensor[0].size(1) + for t in mlvl_tensor: + if not flatten: + t = t.permute(0, 2, 3, 1) + t = t.view(batch_size, -1, channels) + for img in range(batch_size): + batch_list[img].append(t[img]) + return [jt.concat(item, 0) for item in batch_list] + +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. 
+ """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = jt.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, + featmap_sizes, + dtype=jt.float32, + with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[jt.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
+ """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], + level_idx=i, + dtype=dtype, + with_stride=with_stride) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, + featmap_size, + level_idx, + dtype=jt.float32, + with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
+ """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (jt.arange(0, feat_w) + + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (jt.arange(0, feat_h) + + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = jt.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = jt.full((shift_xx.shape[0], ), stride_w, dtype=dtype) + stride_h = jt.full((shift_yy.shape[0], ), stride_h, dtype=dtype) + shifts = jt.stack([shift_xx, shift_yy, stride_w, stride_h], + dim=-1) + all_points = shifts + return all_points + + def valid_flags(self, featmap_sizes, pad_shape): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(jt.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), + (valid_feat_h, valid_feat_w)) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size, + valid_size): + """Generate the valid flags of points of a single feature map. 
+ + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + jt.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = jt.zeros(feat_w, dtype=jt.bool) + valid_y = jt.zeros(feat_h, dtype=jt.bool) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, + prior_idxs, + featmap_size, + level_idx, + dtype=jt.float32): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`jt.dtype`): Date type of points. Defaults to + ``jt.float32``. + device (obj:`jt.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). 
+ """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + + self.offset) * self.strides[level_idx][1] + prioris = jt.stack([x, y], 1).to(dtype) + return prioris diff --git a/python/jdet/models/roi_heads/rotated_reppoints_head.py b/python/jdet/models/roi_heads/rotated_reppoints_head.py index 962887f..d5a3c63 100644 --- a/python/jdet/models/roi_heads/rotated_reppoints_head.py +++ b/python/jdet/models/roi_heads/rotated_reppoints_head.py @@ -15,6 +15,7 @@ from jittor.nn import _pair + import numpy as np def deleteme(a, b, size = 10): if a is None and b is None: @@ -67,7 +68,17 @@ def transpose_to(a, b): print(type(b)) raise NotImplementedError - +def fake_argsort2(x, dim=0, descending=False): + x_ = x.data + if (descending): + x__ = -x_ + else: + x__ = x_ + index_ = np.argsort(x__, axis=dim, kind="stable") + y_ = x_[index_] + index = jt.array(index_) + y = jt.array(y_) + return y, index @HEADS.register_module() class RotatedRepPointsHead(nn.Module): @@ -328,11 +339,13 @@ def _point_target_single(self, stage='init', unmap_outputs=True): """Single point target function.""" + print('A: ', flat_proposals.shape, valid_flags.shape, gt_bboxes.shape, gt_labels.shape) inside_flags = valid_flags if not inside_flags.any(): return (None, ) * 8 # assign gt and sample proposals proposals = flat_proposals[inside_flags, :] + print(' proposals: ', len(proposals)) if stage == 'init': assigner = self.init_assigner @@ -352,16 +365,23 @@ def _point_target_single(self, # convert gt from obb to poly gt_bboxes = obb2poly(gt_bboxes) + + # if stage != 'init': assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore, None if self.sampling else gt_labels) + # assign_result = assigner.assign(proposals, gt_bboxes, + # gt_bboxes_ignore, gt_labels) sampling_result = self.sampler.sample(assign_result, proposals, gt_bboxes) + + # if stage != 'init': - # out_list = [sampling_result.pos_inds, 
sampling_result.neg_inds, - # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds, - # assign_result.gt_inds] + # out_list = [sampling_result.pos_inds, sampling_result.neg_inds, + # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds, + # assign_result.gt_inds] + # print('sampling_result: ', len(sampling_result.pos_inds), len(sampling_result.neg_inds), len(sampling_result.pos_gt_bboxes), len(sampling_result.pos_assigned_gt_inds), len(assign_result.gt_inds)) # result_list = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/result_dict.pkl", "rb")) # deleteme(out_list, result_list) # exit(0) @@ -373,10 +393,12 @@ def _point_target_single(self, labels = jt.full((num_valid_proposals, ), self.background_label, dtype=jt.int32) + # print('init labels: ', labels) label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32) pos_inds = sampling_result.pos_inds neg_inds = sampling_result.neg_inds + print('pos_inds', len(pos_inds)) if len(pos_inds) > 0: pos_gt_bboxes = sampling_result.pos_gt_bboxes bbox_gt[pos_inds, :] = pos_gt_bboxes @@ -409,6 +431,8 @@ def _point_target_single(self, inside_flags) proposals_weights = unmap(proposals_weights, num_total_proposals, inside_flags) + + print('In _point_target_single: ', len(labels), len((labels > 0).nonzero().view(-1))) return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights, pos_inds, neg_inds, sampling_result) @@ -460,6 +484,7 @@ def get_targets(self, """ assert stage in ['init', 'refine'] num_imgs = len(img_metas) + print('num_imgs: ', num_imgs) assert len(proposals_list) == len(valid_flag_list) == num_imgs # points number of multi levels @@ -581,9 +606,14 @@ def get_cfa_targets(self, pos_inds = [] # pos_gt_index = [] for i, single_labels in enumerate(all_labels): - pos_mask = (0 < single_labels) & ( - single_labels <= self.num_classes) #TODO(514flowey): num_class not include background + # pos_mask = (0 < single_labels) & ( + # single_labels <= self.num_classes) 
#TODO(514flowey): num_class not include background + # print('single_labels: ', single_labels) + pos_mask = single_labels > 0 + # pos_mask = (0 < single_labels) & ( + # single_labels <= self.num_classes) pos_inds.append(pos_mask.nonzero().view(-1)) + print('In get_cfa_targets pos_mask.nonzero().view(-1): ', len(pos_mask.nonzero().view(-1)), len(single_labels)) gt_inds = [item.pos_assigned_gt_inds for item in sampling_result] @@ -658,8 +688,6 @@ def loss(self, assert len(featmap_sizes) == self.prior_generator.num_levels label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 - - # import pickle # input_dict = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/input_dict.pkl", "rb")) # featmap_sizes = input_dict['featmap_sizes'] @@ -748,12 +776,17 @@ def loss(self, item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_refine ] + + # labels_init = jt.concat(labels_list, 0).view(-1) + # pos_inds_flatten_init = (labels_init > 0).nonzero().reshape(-1) + with jt.no_grad(): pos_losses_list, = multi_apply( self.get_pos_loss, cls_scores, pts_coordinate_preds_init_cfa, labels_list, rbbox_gt_list_refine, label_weights_list, convex_weights_list_refine, pos_inds_list_refine) + # print('pos_inds_list_refine: ', len(pos_inds_list_refine[0])) labels_list, label_weights_list, convex_weights_list_refine, \ num_pos, pos_normalize_term = multi_apply( self.reassign, @@ -781,11 +814,22 @@ def loss(self, convex_weights_refine = jt.concat(convex_weights_list_refine, 0).view(-1) pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1) - pos_inds_flatten = ((0 <= labels) & - (labels < self.num_classes)).nonzero( - as_tuple=False).reshape(-1) + # pos_inds_flatten = ((0 <= labels) & + # (labels < self.num_classes)).nonzero( + # ).reshape(-1) + pos_inds_flatten = (labels > 0).nonzero().reshape(-1) + + # labels = jt.concat(labels_list, 0).view(-1) + # pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1) + # if len(pos_normalize_term) 
!= len(pos_inds_flatten): + # print('len(pos_normalize_term): ', len(pos_normalize_term)) + # print('len(pos_inds_flatten): ', len(pos_inds_flatten)) + # print('len(pos_inds_flatten_init): ', len(pos_inds_flatten_init)) + # exit() assert len(pos_normalize_term) == len(pos_inds_flatten) - if num_pos: + + + if bool(num_pos): losses_cls = self.loss_cls( cls_scores, labels, labels_weight, avg_factor=num_pos) pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten] @@ -951,10 +995,14 @@ def reassign(self, - pos_normalize_term (list): pos normalize term for refine \ points losses. """ + # print('label: ', label.shape) + print('pos_inds: ', len(pos_inds), len(pos_gt_inds)) + if len(pos_inds) == 0: return label, label_weight, convex_weight, 0, jt.array([]).type_as(convex_weight) num_gt = pos_gt_inds.max() + 1 + num_proposals_each_level_ = num_proposals_each_level.copy() num_proposals_each_level_.insert(0, 0) inds_level_interval = np.cumsum(num_proposals_each_level_) @@ -971,7 +1019,7 @@ def reassign(self, pos_inds_after_cfa = [] ignore_inds_after_cfa = [] re_assign_weights_after_cfa = [] - for gt_ind in range(num_gt): + for gt_ind in range(int(num_gt)): pos_inds_cfa = [] pos_loss_cfa = [] pos_overlaps_init_cfa = [] @@ -980,7 +1028,7 @@ def reassign(self, level_mask = pos_level_mask[level] level_gt_mask = level_mask & gt_mask value, topk_inds = pos_losses[level_gt_mask].topk( - min(level_gt_mask.sum(), self.topk), largest=False) + int(min(level_gt_mask.sum(), self.topk)), largest=False) pos_inds_cfa.append(pos_inds[level_gt_mask][topk_inds]) pos_loss_cfa.append(value) pos_overlaps_init_cfa.append( @@ -990,10 +1038,11 @@ def reassign(self, pos_overlaps_init_cfa = jt.concat(pos_overlaps_init_cfa, 1) if len(pos_inds_cfa) < 2: pos_inds_after_cfa.append(pos_inds_cfa) - ignore_inds_after_cfa.append(jt.empty((0))) + ignore_inds_after_cfa.append(jt.empty([])) re_assign_weights_after_cfa.append(jt.ones([len(pos_inds_cfa)])) else: - pos_loss_cfa, sort_inds = pos_loss_cfa.sort() + # 
pos_loss_cfa, sort_inds = pos_loss_cfa.sort() + pos_loss_cfa, sort_inds = fake_argsort2(pos_loss_cfa) pos_inds_cfa = pos_inds_cfa[sort_inds] pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, sort_inds] \ .reshape(-1, len(pos_inds_cfa)) @@ -1004,8 +1053,9 @@ def reassign(self, gauss_prob_density = \ (-(pos_loss_cfa - loss_mean) ** 2 / loss_var) \ .exp() / loss_var.sqrt() - index_inverted, _ = jt.arange( - len(gauss_prob_density)).sort(descending=True) + # index_inverted, _ = jt.arange( + # len(gauss_prob_density)).sort(descending=True) + index_inverted, _ = fake_argsort2(jt.arange(len(gauss_prob_density)), descending=True) gauss_prob_inverted = jt.cumsum( gauss_prob_density[index_inverted], 0) gauss_prob = gauss_prob_inverted[index_inverted] @@ -1015,6 +1065,7 @@ def reassign(self, # splitting by gradient consistency loss_curve = gauss_prob_norm * pos_loss_cfa _, max_thr = loss_curve.topk(1) + max_thr = int(max_thr) reweights = gauss_prob_norm[:max_thr + 1] # feature anti-aliasing coefficient pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, :max_thr + 1] @@ -1028,19 +1079,24 @@ def reassign(self, jt.ones(len(reweights)).type_as( gauss_prob_norm).sum() pos_inds_temp = pos_inds_cfa[:max_thr + 1] - ignore_inds_temp = pos_inds_cfa.new_tensor([]) - + # print('pos_inds_cfa: ', pos_inds_cfa) + # print('pos_inds_temp: ', pos_inds_temp) + ignore_inds_temp = jt.empty([]) + pos_inds_after_cfa.append(pos_inds_temp) ignore_inds_after_cfa.append(ignore_inds_temp) re_assign_weights_after_cfa.append(re_assign_weights) - + pos_inds_after_cfa = jt.concat(pos_inds_after_cfa) ignore_inds_after_cfa = jt.concat(ignore_inds_after_cfa) re_assign_weights_after_cfa = jt.concat(re_assign_weights_after_cfa) - + print('pos_inds_after_cfa: ', len(pos_inds), len(pos_inds_after_cfa)) reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_cfa).all(1) + print('reassign_mask: ', len(reassign_mask)) reassign_ids = pos_inds[reassign_mask] - label[reassign_ids] = self.num_classes + # 
label[reassign_ids] = self.num_classes + # print('reassign_ids: ', reassign_ids) + label[reassign_ids] = 0 label_weight[ignore_inds_after_cfa] = 0 convex_weight[reassign_ids] = 0 num_pos = len(pos_inds_after_cfa) @@ -1060,11 +1116,21 @@ def reassign(self, 0).type_as(label) pos_normalize_term = pos_level_mask_after_cfa * ( self.point_base_scale * - jt.as_tensor(self.point_strides).type_as(label)).reshape(-1, 1) + jt.array(self.point_strides).type_as(label)).reshape(-1, 1) pos_normalize_term = pos_normalize_term[ pos_normalize_term > 0].type_as(convex_weight) assert len(pos_normalize_term) == len(pos_inds_after_cfa) + # label = jt.concat(label, 0).view(-1) + pos_inds_flatten = (label > 0).nonzero().reshape(-1) + # print('len(pos_normalize_term): ', len(pos_normalize_term)) + # print('len(pos_inds_flatten): ', len(pos_inds_flatten)) + # pos_normalize_term = jt.concat(pos_normalize_term, 0).reshape(-1) + # if len(pos_normalize_term) != len(pos_inds_flatten): + # print('len(pos_normalize_term): ', len(pos_normalize_term)) + # print('len(pos_inds_flatten): ', len(pos_inds_flatten)) + # exit() + return label, label_weight, convex_weight, num_pos, pos_normalize_term def get_bboxes(self, diff --git a/python/jdet/models/roi_heads/sam_reppoints_head.py b/python/jdet/models/roi_heads/sam_reppoints_head.py new file mode 100644 index 0000000..5398e01 --- /dev/null +++ b/python/jdet/models/roi_heads/sam_reppoints_head.py @@ -0,0 +1,1308 @@ +from jdet.utils.registry import HEADS, LOSSES, BOXES, build_from_cfg +from jdet.models.utils.modules import ConvModule +from jdet.ops.dcn_v1 import DeformConv +from jdet.ops.bbox_transforms import obb2poly, poly2obb +from jdet.utils.general import multi_apply, unmap +from jdet.ops.nms_rotated import multiclass_nms_rotated +from jdet.models.boxes.anchor_target import images_to_levels +from jdet.ops.reppoints_min_area_bbox import reppoints_min_area_bbox +from jdet.models.boxes.box_ops import rotated_box_to_poly + + +import jittor as jt +from 
import numpy as np


def get_num_level_anchors_inside(num_level_anchors, inside_flags):
    """Get number of every level's anchors inside the image.

    Args:
        num_level_anchors (list[int]): Number of anchors on every level.
        inside_flags (jt.Var): Boolean flags of all anchors, with all
            levels concatenated.

    Returns:
        list[int]: Number of inside anchors for each level.
    """
    split_inside_flags = jt.split(inside_flags, num_level_anchors)
    num_level_anchors_inside = [
        int(flags.sum()) for flags in split_inside_flags
    ]
    return num_level_anchors_inside


def points_center_pts(RPoints, y_first=True):
    """Compute the mean center point of each 9-point set.

    Args:
        RPoints (jt.Var): Point sets, shape (k, 18).
        y_first (bool, optional): If True the per-point coordinate order is
            (y, x), otherwise (x, y).

    Returns:
        jt.Var: Mean (x, y) center of each point set, shape (k, 2).
    """
    RPoints = RPoints.reshape(-1, 9, 2)

    if y_first:
        pts_dy = RPoints[:, :, 0::2]
        pts_dx = RPoints[:, :, 1::2]
    else:
        pts_dx = RPoints[:, :, 0::2]
        pts_dy = RPoints[:, :, 1::2]
    pts_dy_mean = pts_dy.mean(dim=1).reshape(-1, 1)
    pts_dx_mean = pts_dx.mean(dim=1).reshape(-1, 1)
    center_pts = jt.concat([pts_dx_mean, pts_dy_mean], dim=1).reshape(-1, 2)
    return center_pts


def to_oc(boxes):
    """Normalize rotated boxes (x, y, w, h, theta) to an OpenCV-style angle.

    The angle is shifted into [-pi/2, 0); when the normalized angle lands in
    the second quarter-turn, width and height are swapped so the described
    rectangle stays geometrically identical.

    Args:
        boxes (jt.Var): Rotated boxes, shape (..., 5).

    Returns:
        jt.Var: Boxes with normalized angle definition, same shape.
    """
    x, y, w, h, t = boxes.unbind(dim=-1)
    start_angle = -0.5 * np.pi
    t = ((t - start_angle) % np.pi)
    w_ = jt.where(t < np.pi / 2, w, h)
    h_ = jt.where(t < np.pi / 2, h, w)
    t = jt.where(t < np.pi / 2, t, t - np.pi / 2) + start_angle
    return jt.stack([x, y, w_, h_, t], dim=-1)


def deleteme(a, b, size=10):
    """Debug helper: recursively print element-wise diffs between two nested
    structures (jt.Var/np.ndarray/scalars).  Not used in production paths.
    """
    if a is None and b is None:
        return
    if isinstance(a, dict) and isinstance(b, dict):
        print('-' * size)
        for a1, b1 in zip(a.values(), b.values()):
            deleteme(a1, b1, size + 10)
        print('-' * size)
    elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        print('-' * size)
        for a1, b1 in zip(a, b):
            deleteme(a1, b1, size + 10)
        print('-' * size)
    elif isinstance(a, jt.Var) and isinstance(b, np.ndarray):
        print((a - b).abs().max().item())
    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        print(np.max(np.abs(a - b)))
    elif isinstance(a, (int, float)) and isinstance(b, (int, float)):
        print("number diff:", a - b)
    else:
        print(type(a))
        print(type(b))
        raise NotImplementedError


def transpose_to(a, b):
    """Debug helper: recursively convert structure ``a`` so its leaves match
    the leaf types of the reference structure ``b`` (np.ndarray -> jt.Var
    where ``b`` holds a jt.Var).  Used to replay dumped reference inputs.

    Args:
        a: Nested list/tuple/dict with numpy-array or scalar leaves, or None.
        b: Reference structure with the same nesting as ``a``.

    Returns:
        Structure mirroring ``a`` with converted leaves (None stays None).
    """
    if a is None:
        return None
    if isinstance(a, list) and isinstance(b, list):
        rlist = []
        for a1, b1 in zip(a, b):
            rlist.append(transpose_to(a1, b1))
        return rlist
    elif isinstance(a, dict) and isinstance(b, dict):
        # BUG FIX: this was `rdict = []`; assigning by string key into a
        # list raises TypeError, a dict is clearly intended here.
        rdict = {}
        for k in b.keys():
            rdict[k] = transpose_to(a[k], b[k])
        return rdict
    elif isinstance(a, np.ndarray) and isinstance(b, jt.Var):
        return jt.array(a)
    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a
    elif isinstance(a, tuple) and isinstance(b, tuple):
        rlist = [transpose_to(a1, b1) for a1, b1 in zip(a, b)]
        return tuple(rlist)
    elif isinstance(a, (int, float, str)) and isinstance(b, (int, float, str)):
        assert (type(a) == type(b))
        return a
    else:
        print(type(a))
        print(type(b))
        raise NotImplementedError


def fake_argsort2(x, dim=0, descending=False):
    """Stable sort implemented through numpy.

    Detaches ``x`` to numpy, runs a *stable* argsort and wraps the sorted
    values and indices back into jittor Vars.  NOTE(review): gradients do
    not flow through the returned Vars — presumably a deliberate workaround
    for jt.Var.sort; confirm before using on a differentiable path.

    Args:
        x (jt.Var): Values to sort.
        dim (int, optional): Axis to sort along.
        descending (bool, optional): Sort in decreasing order if True.

    Returns:
        tuple(jt.Var, jt.Var): (sorted values, argsort indices).
    """
    x_ = x.data
    if (descending):
        x__ = -x_
    else:
        x__ = x_
    index_ = np.argsort(x__, axis=dim, kind="stable")
    y_ = x_[index_]
    index = jt.array(index_)
    y = jt.array(y_)
    return y, index
    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels,
                 point_feat_channels=256,
                 stacked_convs=3,
                 num_points=9,
                 gradient_mul=0.1,
                 point_strides=[8, 16, 32, 64, 128],
                 point_base_scale=4,
                 background_label=0,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_init=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 center_init=True,
                 transform_method='rotrect',
                 topk=6,
                 anti_factor=0.75,
                 **kwargs):
        """Build the SAM RepPoints head from config.

        NOTE(review): there is no super().__init__() call — confirm jittor's
        nn.Module tolerates this; other heads in the project may differ.
        """
        self.num_points = num_points
        self.point_feat_channels = point_feat_channels
        self.center_init = center_init

        # we use deform conv to extract points features; num_points must be
        # an odd perfect square so the points form a (k x k) DCN kernel grid
        self.dcn_kernel = int(np.sqrt(num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        assert self.dcn_kernel * self.dcn_kernel == num_points, \
            'The points number should be a square number.'
        assert self.dcn_kernel % 2 == 1, \
            'The points number should be an odd square number.'
        # base DCN offsets: the regular (k x k) grid centered on the anchor
        # point, stored y-first to match the offset layout DeformConv expects
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x],
                                   axis=1).reshape((-1))
        self.dcn_base_offset = jt.array(dcn_base_offset).view(1, -1, 1, 1)
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_from_cfg(loss_cls, LOSSES)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        self.point_strides = point_strides
        # NOTE(review): MlvlPointGenerator is not in this file's visible
        # import block — verify it is imported, otherwise this raises
        # NameError at construction time.
        self.prior_generator = MlvlPointGenerator(self.point_strides, offset=0.)
        self.num_base_priors = self.prior_generator.num_base_priors[0]
        # FocalLoss handles pos/neg weighting itself, so no sampler is used
        self.sampling = loss_cls['type'] not in ['FocalLoss']
        if self.train_cfg:
            self.init_assigner = build_from_cfg(self.train_cfg.init.assigner, BOXES)
            self.refine_assigner = build_from_cfg(self.train_cfg.refine.assigner, BOXES)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_from_cfg(sampler_cfg, BOXES)
        self.transform_method = transform_method
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # with sigmoid there is no explicit background channel
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        self.loss_bbox_init = build_from_cfg(loss_bbox_init, LOSSES)
        self.loss_bbox_refine = build_from_cfg(loss_bbox_refine, LOSSES)
        self.topk = topk
        self.anti_factor = anti_factor
        self.background_label = background_label

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head.

        Two parallel conv towers (cls / reg), a 3x3 conv pair producing the
        initial point offsets, and two DeformConv branches (driven by those
        offsets) for classification and point refinement.
        """
        self.relu = nn.ReLU()
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        # each point contributes an (y, x) offset pair
        pts_out_dim = 2 * self.num_points
        self.reppoints_cls_conv = DeformConv(self.feat_channels,
                                             self.point_feat_channels,
                                             self.dcn_kernel, 1,
                                             self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
                                           self.cls_out_channels, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
                                                 self.point_feat_channels, 3,
                                                 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
                                                pts_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
                                                    self.point_feat_channels,
                                                    self.dcn_kernel, 1,
                                                    self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
                                                  pts_out_dim, 1, 1, 0)

    def points2rotrect(self, pts, y_first=True):
        """Convert point sets to oriented bboxes (min-area rectangle).

        Args:
            pts (jt.Var): Point sets, shape (..., 2 * num_points).
            y_first (bool, optional): If True, input coordinates are stored
                y-first and are swapped to x-first before conversion.

        Returns:
            jt.Var: Min-area oriented rectangles enclosing each point set.

        Raises:
            NotImplementedError: For any transform_method except 'rotrect'.
        """
        if y_first:
            pts = pts.reshape(-1, self.num_points, 2)
            pts_dy = pts[:, :, 0::2]
            pts_dx = pts[:, :, 1::2]
            pts = jt.concat([pts_dx, pts_dy],
                            dim=2).reshape(-1, 2 * self.num_points)
        if self.transform_method == 'rotrect':
            rotrect_pred = reppoints_min_area_bbox(pts)
            return rotrect_pred
        else:
            raise NotImplementedError

    def forward_single(self, x):
        """Forward feature map of a single FPN level.

        Returns:
            tuple: (cls_out, pts_out_init, pts_out_refine) where the point
            outputs are per-location offset fields of size 2 * num_points.
        """
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        points_init = 0
        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)
        # initialize reppoints
        pts_out_init = self.reppoints_pts_init_out(
            self.relu(self.reppoints_pts_init_conv(pts_feat)))
        pts_out_init = pts_out_init + points_init
        # refine and classify reppoints: gradient_mul damps the gradient
        # flowing from the refine/cls branches back into the init offsets
        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
        # DeformConv offsets are relative to the base kernel grid
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
        cls_out = self.reppoints_cls_out(
            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(
            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
        # refine output is a residual on top of the (detached) init points
        pts_out_refine = pts_out_refine + pts_out_init.detach()

        return cls_out, pts_out_init, pts_out_refine
self.reg_convs: + pts_feat = reg_conv(pts_feat) + # initialize reppoints + pts_out_init = self.reppoints_pts_init_out( + self.relu(self.reppoints_pts_init_conv(pts_feat))) + pts_out_init = pts_out_init + points_init + # refine and classify reppoints + pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init + dcn_offset = pts_out_init_grad_mul - dcn_base_offset + cls_out = self.reppoints_cls_out( + self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) + pts_out_refine = self.reppoints_pts_refine_out( + self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) + pts_out_refine = pts_out_refine + pts_out_init.detach() + + return cls_out, pts_out_init, pts_out_refine + + def get_points(self, featmap_sizes, img_metas): + """Get points according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + + Returns: + tuple: points of each image, valid flags of each image + """ + num_imgs = len(img_metas) + + multi_level_points = self.prior_generator.grid_priors( + featmap_sizes, with_stride=True) + points_list = [[point.clone() for point in multi_level_points] + for _ in range(num_imgs)] + + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = self.prior_generator.valid_flags( + featmap_sizes, img_meta['pad_shape']) + valid_flag_list.append(multi_level_flags) + + return points_list, valid_flag_list + + def offset_to_pts(self, center_list, pred_list): + """Change from point offset to point coordinate.""" + pts_list = [] + for i_lvl, _ in enumerate(self.point_strides): + pts_lvl = [] + for i_img, _ in enumerate(center_list): + pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points) + pts_shift = pred_list[i_lvl][i_img] + yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points) + y_pts_shift = yx_pts_shift[..., 0::2] + x_pts_shift = yx_pts_shift[..., 1::2] + 
xy_pts_shift = jt.stack([x_pts_shift, y_pts_shift], -1) + xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1) + pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center + pts_lvl.append(pts) + pts_lvl = jt.stack(pts_lvl, 0) + pts_list.append(pts_lvl) + return pts_list + + def _point_target_single(self, + flat_proposals, + num_level_proposals, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + overlaps, + stage='init', + unmap_outputs=True): + """Single point target function.""" + # print('A: ', flat_proposals.shape, valid_flags.shape, gt_bboxes.shape, gt_labels.shape) + inside_flags = valid_flags + if not inside_flags.any(): + return (None, ) * 9 + # assign gt and sample proposals + proposals = flat_proposals[inside_flags, :] + + num_level_anchors_inside = get_num_level_anchors_inside( + num_level_proposals, inside_flags) + + # convert gt from obb to poly + gt_bboxes = obb2poly(gt_bboxes) + + if stage == 'init': + assigner = self.init_assigner + pos_weight = self.train_cfg.init.pos_weight + assign_result = assigner.assign(proposals, gt_bboxes, + gt_bboxes_ignore, + None if self.sampling else gt_labels) + else: + assigner = self.refine_assigner + pos_weight = self.train_cfg.refine.pos_weight + if self.train_cfg.refine.assigner.type not in ( + 'ATSSAssigner', 'ATSSConvexAssigner', 'SASAssigner'): + assign_result = assigner.assign( + proposals, gt_bboxes, overlaps, gt_bboxes_ignore, + None if self.sampling else gt_labels) + else: + assign_result = assigner.assign( + proposals, num_level_anchors_inside, gt_bboxes, + gt_bboxes_ignore, None if self.sampling else gt_labels) + + sampling_result = self.sampler.sample(assign_result, proposals, + gt_bboxes) + # if stage != 'init': + # import pickle + # input_dict = pickle.load(open("/mnt/disk/flowey/remote/JDet-debug/weights/input_dict.pkl", "rb")) + # proposals = transpose_to(input_dict['proposals'], proposals) + # gt_bboxes = transpose_to(input_dict['gt_bboxes'], gt_bboxes) + # gt_labels = 
transpose_to(input_dict['gt_labels'], gt_labels) + # gt_bboxes_ignore = transpose_to(input_dict['gt_bboxes_ignore'], gt_bboxes_ignore) + + + + + # if stage != 'init': + # assign_result = assigner.assign(proposals, gt_bboxes, + # gt_bboxes_ignore, + # None if self.sampling else gt_labels) + # # assign_result = assigner.assign(proposals, gt_bboxes, + # # gt_bboxes_ignore, gt_labels) + # sampling_result = self.sampler.sample(assign_result, proposals, + # gt_bboxes) + + + + # if stage != 'init' and len(sampling_result.pos_inds) == 21824: + # import pickle + # input_list = [flat_proposals, valid_flags, proposals, gt_bboxes, gt_labels] + # print('input_list: ', proposals[0],proposals[10000], proposals[20000]) + # print('input_list: ', gt_bboxes[0]) + # result_list = pickle.dump(input_list, open("/home/zytx121/jittor/JDet-refactor_r3det/input_dict.pkl", 'wb')) + + # out_list = [sampling_result.pos_inds, sampling_result.neg_inds, + # sampling_result.pos_gt_bboxes, sampling_result.pos_assigned_gt_inds, + # assign_result.gt_inds] + # result_list = pickle.dump(out_list, open("/home/zytx121/jittor/JDet-refactor_r3det/result_dict.pkl", 'wb')) + # # print('out_list: ', out_list) + + # exit(0) + + gt_inds = assign_result.gt_inds + num_valid_proposals = proposals.shape[0] + bbox_gt = jt.zeros([num_valid_proposals, 8], dtype=proposals.dtype) + pos_proposals = jt.zeros_like(proposals) + proposals_weights = jt.zeros(num_valid_proposals, dtype=proposals.dtype) + labels = jt.full((num_valid_proposals, ), + self.background_label, + dtype=jt.int32) + # print('init labels: ', labels) + label_weights = jt.zeros((num_valid_proposals,), dtype=jt.float32) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + # print('pos_inds', len(pos_inds)) + if len(pos_inds) > 0: + pos_gt_bboxes = sampling_result.pos_gt_bboxes + bbox_gt[pos_inds, :] = pos_gt_bboxes + pos_proposals[pos_inds, :] = proposals[pos_inds, :] + proposals_weights[pos_inds] = 1.0 + if gt_labels is None: + # 
Only rpn gives gt_labels as None + # Foreground is the first class + # TODO(514flowey): first class is 1 + # labels[pos_inds] = 0 + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # use la + rbboxes_center, width, height, angles = jt.split( + to_oc(poly2obb(bbox_gt)), [2, 1, 1, 1], dim=-1) + + if stage == 'init': + points_xy = pos_proposals[:, :2] + else: + points_xy = points_center_pts(pos_proposals, y_first=True) + + distances = jt.zeros_like(angles).reshape(-1) + # print('angles: ', angles.max(), angles.min()) + # angles_index_wh = ((width != 0) & (angles >= 0) & + # (angles <= 1.57)).squeeze() + # angles_index_hw = ((width != 0) & ((angles < 0) | + # (angles > 1.57))).squeeze() + + # print('angles_index_hw: ', bool(angles_index_hw.sum()), bool(angles_index_wh.sum())) + + # 01_la:compution of distance + distances = jt.sqrt( + (jt.pow( + rbboxes_center[:, 0] - + points_xy[:, 0], 2) / + width.squeeze()) + + (jt.pow( + rbboxes_center[:, 1] - + points_xy[:, 1], 2) / + height.squeeze())) + # if bool(angles_index_wh.sum()): + # distances[angles_index_wh] = jt.sqrt( + # (jt.pow( + # rbboxes_center[angles_index_wh, 0] - + # points_xy[angles_index_wh, 0], 2) / + # width[angles_index_wh].squeeze()) + + # (jt.pow( + # rbboxes_center[angles_index_wh, 1] - + # points_xy[angles_index_wh, 1], 2) / + # height[angles_index_wh].squeeze())) + # if bool(angles_index_hw.sum()): + # distances[angles_index_hw] = jt.sqrt( + # (jt.pow( + # rbboxes_center[angles_index_hw, 0] - + # points_xy[angles_index_hw, 0], 2) / + # height[angles_index_hw].squeeze()) + + # (jt.pow( + # rbboxes_center[angles_index_hw, 1] - + # points_xy[angles_index_hw, 1], 2) / + # width[angles_index_hw].squeeze())) + distances[distances == float('nan')] = 0. 
    def get_targets(self,
                    proposals_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    stage='init',
                    label_channels=1,
                    unmap_outputs=True):
        """Compute GT box and classification targets for all images.

        Runs `_point_target_single` per image, then regroups the per-image
        results into per-level lists.  NOTE: mutates ``proposals_list`` and
        ``valid_flag_list`` in place (levels are concatenated per image).

        Args:
            proposals_list (list[list]): Multi level points/bboxes of each
                image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            stage (str): `init` or `refine`.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: per-level labels, label weights, bbox targets, proposals,
            proposal weights, total pos/neg counts, per-level assigned GT
            indices and per-level SAM weights — or None when any image has
            no valid points.
        """
        assert stage in ['init', 'refine']
        num_imgs = len(img_metas)
        assert len(proposals_list) == len(valid_flag_list) == num_imgs

        # points number of multi levels
        num_level_proposals = [points.size(0) for points in proposals_list[0]]
        num_level_proposals_list = [num_level_proposals] * num_imgs

        # concat all level points and flags to a single tensor per image
        for i in range(num_imgs):
            assert len(proposals_list[i]) == len(valid_flag_list[i])
            proposals_list[i] = jt.concat(proposals_list[i])
            valid_flag_list[i] = jt.concat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        len_gt_labels = len(gt_bboxes_list)
        # overlaps are only consumed by non-ATSS refine assigners; None here
        all_overlaps_rotate_list = [None] * len_gt_labels
        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
         all_proposal_weights, pos_inds_list, neg_inds_list, all_gt_inds_list,
         all_sam_init_weights) = multi_apply(
             self._point_target_single,
             proposals_list,
             num_level_proposals_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             all_overlaps_rotate_list,
             stage=stage,
             unmap_outputs=unmap_outputs)
        # no valid points
        if any([labels is None for labels in all_labels]):
            return None
        # sampled points of all images (clamped to >= 1 per image so the
        # averaging factor never becomes zero)
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])

        labels_list = images_to_levels(all_labels, num_level_proposals)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_proposals)
        bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
        proposals_list = images_to_levels(all_proposals, num_level_proposals)
        proposal_weights_list = images_to_levels(all_proposal_weights,
                                                 num_level_proposals)
        gt_inds_list = images_to_levels(all_gt_inds_list, num_level_proposals)
        sam_init_weights_list = images_to_levels(all_sam_init_weights,
                                                 num_level_proposals)

        return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
                proposal_weights_list, num_total_pos, num_total_neg,
                gt_inds_list, sam_init_weights_list)
labels_list = images_to_levels(all_labels, num_level_proposals) + label_weights_list = images_to_levels(all_label_weights, + num_level_proposals) + bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) + proposals_list = images_to_levels(all_proposals, num_level_proposals) + proposal_weights_list = images_to_levels(all_proposal_weights, + num_level_proposals) + gt_inds_list = images_to_levels(all_gt_inds_list, num_level_proposals) + sam_init_weights_list = images_to_levels(all_sam_init_weights, + num_level_proposals) + + return (labels_list, label_weights_list, bbox_gt_list, proposals_list, + proposal_weights_list, num_total_pos, num_total_neg, + gt_inds_list, sam_init_weights_list) + + def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels, + label_weights, rbbox_gt_init, convex_weights_init, + sam_weights_init, rbbox_gt_refine, convex_weights_refine, + sam_weights_refine, stride, num_total_samples_refine): + """Single loss function.""" + normalize_term = self.point_base_scale * stride + + rbbox_gt_init = rbbox_gt_init.reshape(-1, 8) + convex_weights_init = convex_weights_init.reshape(-1) + sam_weights_init = sam_weights_init.reshape(-1) + # init points loss + pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) + pos_ind_init = (convex_weights_init > 0).nonzero().reshape(-1) + pts_pred_init_norm = pts_pred_init[pos_ind_init] + rbbox_gt_init_norm = rbbox_gt_init[pos_ind_init] + convex_weights_pos_init = convex_weights_init[pos_ind_init] + sam_weights_pos_init = sam_weights_init[pos_ind_init] + loss_pts_init = self.loss_bbox_init( + pts_pred_init_norm / normalize_term, + rbbox_gt_init_norm / normalize_term, + convex_weights_pos_init * sam_weights_pos_init) + # refine points loss + rbbox_gt_refine = rbbox_gt_refine.reshape(-1, 8) + pts_pred_refine = pts_pred_refine.reshape(-1, 2 * self.num_points) + convex_weights_refine = convex_weights_refine.reshape(-1) + sam_weights_refine = sam_weights_refine.reshape(-1) + pos_ind_refine = 
    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function of SAM RepPoints head.

        Builds init-stage targets from the prior centers and refine-stage
        targets from the (detached) init point predictions, then applies
        `loss_single` per FPN level.
        """

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        # target for initial stage
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)

        # init stage assigns on the raw grid centers
        if self.train_cfg.init.assigner['type'] == 'ConvexAssigner':
            candidate_list = center_list
        else:
            raise NotImplementedError
        cls_reg_targets_init = self.get_targets(
            candidate_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='init',
            label_channels=label_channels)
        (*_, rbbox_gt_list_init, candidate_list_init, convex_weights_list_init,
         num_total_pos_init, num_total_neg_init, gt_inds_init,
         sam_weights_list_init) = cls_reg_targets_init

        # target for refinement stage: get_targets mutated the lists above,
        # so the grid points are regenerated here
        center_list, valid_flag_list = self.get_points(
            featmap_sizes, img_metas)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)
        # refine-stage candidates are the absolute init points (detached so
        # target assignment does not backprop into the init branch)
        points_list = []
        for i_img, center in enumerate(center_list):
            points = []
            for i_lvl in range(len(pts_preds_refine)):
                points_preds_init_ = pts_preds_init[i_lvl].detach()
                points_preds_init_ = points_preds_init_.view(
                    points_preds_init_.shape[0], -1,
                    *points_preds_init_.shape[2:])
                points_shift = points_preds_init_.permute(
                    0, 2, 3, 1) * self.point_strides[i_lvl]
                points_center = center[i_lvl][:, :2].repeat(1, self.num_points)
                points.append(
                    points_center +
                    points_shift[i_img].reshape(-1, 2 * self.num_points))
            points_list.append(points)

        cls_reg_targets_refine = self.get_targets(
            points_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            stage='refine',
            label_channels=label_channels)
        (labels_list, label_weights_list, rbbox_gt_list_refine,
         candidate_list_refine, convex_weights_list_refine,
         num_total_pos_refine, num_total_neg_refine, gt_inds_refine,
         sam_weights_list_refine) = cls_reg_targets_refine
        # negatives only enter the normalizer when a real sampler is used
        num_total_samples_refine = (
            num_total_pos_refine +
            num_total_neg_refine if self.sampling else num_total_pos_refine)

        losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            labels_list,
            label_weights_list,
            rbbox_gt_list_init,
            convex_weights_list_init,
            sam_weights_list_init,
            rbbox_gt_list_refine,
            convex_weights_list_refine,
            sam_weights_list_refine,
            self.point_strides,
            num_total_samples_refine=num_total_samples_refine)
        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine
        }
        return loss_dict_all
+ """ + if pos_inds.size(0) == 0: + pos_loss = jt.zeros((0)) + return pos_loss, + pos_scores = cls_score[pos_inds] + pos_pts_pred = pts_pred[pos_inds] + pos_bbox_gt = bbox_gt[pos_inds] + pos_label = label[pos_inds] + pos_label_weight = label_weight[pos_inds] + pos_convex_weight = convex_weight[pos_inds] + loss_cls = self.loss_cls( + pos_scores, + pos_label, + pos_label_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + loss_bbox = self.loss_bbox_refine( + pos_pts_pred, + pos_bbox_gt, + pos_convex_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + loss_cls = loss_cls.sum(-1) + pos_loss = loss_bbox + loss_cls + return pos_loss, + + def get_bboxes(self, + cls_scores, + pts_preds_init, + pts_preds_refine, + img_metas, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Transform network outputs of a batch into bbox results. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + pts_preds_init (list[Tensor]): Box energies / deltas for all + scale levels, each is a 18D-tensor, has shape + (batch_size, num_points * 2, H, W). + pts_preds_refine (list[Tensor]): Box energies / deltas for all + scale levels, each is a 18D-tensor, has shape + (batch_size, num_points * 2, H, W). + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. + + Returns: + list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 6) tensor, where the first 4 columns + are bounding box positions (cx, cy, w, h, a) and the + 6-th column is a score between 0 and 1. 
The second item is a + (n,) tensor where each item is the predicted class label of + the corresponding box. + """ + assert len(cls_scores) == len(pts_preds_refine) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype) + + result_list = [] + + for img_id, _ in enumerate(img_metas): + img_meta = img_metas[img_id] + cls_score_list = select_single_mlvl(cls_scores, img_id) + point_pred_list = select_single_mlvl(pts_preds_refine, img_id) + + results = self._get_bboxes_single(cls_score_list, point_pred_list, + mlvl_priors, img_meta, cfg, + rescale, with_nms, **kwargs) + result_list.append(results) + + return result_list + + def _get_bboxes_single(self, + cls_score_list, + point_pred_list, + mlvl_priors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RepPoints head does not need + this value. + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, has shape + (num_priors, 2). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + Returns: + tuple[Tensor]: Results of detected bboxes and labels. 
If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bboxes, 5], where the first 4 columns are bounding \ + box positions (cx, cy, w, h, a) and the 5-th \ + column are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bboxes]. + """ + + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(point_pred_list) + scale_factor = img_meta['scale_factor'] + + mlvl_bboxes = [] + mlvl_scores = [] + for level_idx, (cls_score, points_pred, points) in enumerate( + zip(cls_score_list, point_pred_list, mlvl_priors)): + assert cls_score.size()[-2:] == points_pred.size()[-2:] + + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1)[:, :-1] + + points_pred = points_pred.permute(1, 2, 0).reshape( + -1, 2 * self.num_points) + nms_pre = cfg.get('nms_pre', -1) + if 0 < nms_pre < scores.shape[0]: + if self.use_sigmoid_cls: + max_scores = scores.max(dim=1) + else: + max_scores = scores[:, 1:].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + points = points[topk_inds, :] + points_pred = points_pred[topk_inds, :] + scores = scores[topk_inds, :] + + poly_pred = self.points2rotrect(points_pred, y_first=True) + bbox_pos_center = points[:, :2].repeat(1, 4) + polys = poly_pred * self.point_strides[level_idx] + bbox_pos_center + bboxes = poly2obb(polys) + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + + mlvl_bboxes = jt.concat(mlvl_bboxes) + + if rescale: + mlvl_bboxes[..., :4] /= mlvl_bboxes[..., :4].new_tensor( + scale_factor) + mlvl_scores = jt.concat(mlvl_scores) + if self.use_sigmoid_cls: + padding = 
jt.zeros((mlvl_scores.shape[0], 1), dtype=mlvl_scores.dtype) + mlvl_scores = jt.concat([padding, mlvl_scores], dim=1) + + if with_nms: + det_bboxes, det_labels = multiclass_nms_rotated( + mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, + cfg.max_per_img) + boxes = det_bboxes[:, :5] + scores = det_bboxes[:, 5] + polys = rotated_box_to_poly(boxes) + return polys, scores, det_labels + else: + raise NotImplementedError + + def parse_targets(self, targets): + img_metas = [] + gt_bboxes = [] + gt_bboxes_ignore = [] + gt_labels = [] + + for target in targets: + if self.is_training(): + gt_bboxes.append(target["rboxes"]) + gt_labels.append(target["labels"]) + gt_bboxes_ignore.append(target["rboxes_ignore"]) + img_metas.append(dict( + img_shape=target["img_size"][::-1], + scale_factor=target["scale_factor"], + pad_shape = target["pad_shape"] + )) + if not self.is_training(): + return dict(img_metas = img_metas) + return dict( + gt_bboxes = gt_bboxes, + gt_labels = gt_labels, + img_metas = img_metas, + gt_bboxes_ignore = gt_bboxes_ignore, + ) + + def execute(self, feats, targets): + outs = multi_apply(self.forward_single, feats) + if self.is_training(): + return self.loss(*outs, **self.parse_targets(targets)) + return self.get_bboxes(*outs, **self.parse_targets(targets)) + + +def select_single_mlvl(mlvl_tensors, batch_id, detach=True): + """Extract a multi-scale single image tensor from a multi-scale batch + tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, + each is a 4D-tensor. + batch_id (int): Batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: Multi-scale single image tensor. 
+ """ + assert isinstance(mlvl_tensors, (list, tuple)) + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id].detach() for i in range(num_levels) + ] + else: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id] for i in range(num_levels) + ] + return mlvl_tensor_list + +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = jt.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, + featmap_sizes, + dtype=jt.float32, + with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[jt.Tensor]: Points of multiple feature levels. 
+ The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], + level_idx=i, + dtype=dtype, + with_stride=with_stride) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, + featmap_size, + level_idx, + dtype=jt.float32, + with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: jt.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
+ """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (jt.arange(0, feat_w) + + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (jt.arange(0, feat_h) + + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = jt.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = jt.full((shift_xx.shape[0], ), stride_w, dtype=dtype) + stride_h = jt.full((shift_yy.shape[0], ), stride_h, dtype=dtype) + shifts = jt.stack([shift_xx, shift_yy, stride_w, stride_h], + dim=-1) + all_points = shifts + return all_points + + def valid_flags(self, featmap_sizes, pad_shape): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(jt.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), + (valid_feat_h, valid_feat_w)) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size, + valid_size): + """Generate the valid flags of points of a single feature map. 
+ + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + jt.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = jt.zeros(feat_w, dtype=jt.bool) + valid_y = jt.zeros(feat_h, dtype=jt.bool) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, + prior_idxs, + featmap_size, + level_idx, + dtype=jt.float32): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`jt.dtype`): Date type of points. Defaults to + ``jt.float32``. + device (obj:`jt.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). 
+ """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + + self.offset) * self.strides[level_idx][1] + prioris = jt.stack([x, y], 1).to(dtype) + return prioris diff --git a/python/jdet/ops/chamfer_distance.py b/python/jdet/ops/chamfer_distance.py new file mode 100644 index 0000000..55475de --- /dev/null +++ b/python/jdet/ops/chamfer_distance.py @@ -0,0 +1,240 @@ +# Modified from +# https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu + +import jittor as jt +from jittor import Function +import numpy as np + +HEADER=r""" +#define MAX_SHARED_SCALAR_T 6144 +#define THREADS_PER_BLOCK 1024 + +inline int GET_BLOCKS(const int N) { + int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + int max_block_num = 65000; + return std::min(optimal_block_num, max_block_num); +} + +template +__global__ void chamfer_distance_forward_cuda_kernel(int b, int n, + const scalar_t* xyz, int m, + const scalar_t* xyz2, + scalar_t* result, + int* result_i) { + __shared__ scalar_t buf[MAX_SHARED_SCALAR_T]; + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) { + int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2; + for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) { + buf[j] = xyz2[(i * m + k2) * 2 + j]; + } + __syncthreads(); + for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) { + scalar_t x1 = xyz[(i * n + j) * 2 + 0]; + scalar_t y1 = xyz[(i * n + j) * 2 + 1]; + int best_i = 0; + scalar_t best = 1e10; + int end_ka = end_k & (~2); + if (end_ka == THREADS_PER_BLOCK) { + for (int k = 0; k < THREADS_PER_BLOCK; k += 4) { +#pragma unroll + for (int j = 0; j < 4; ++j) { + scalar_t x2 = buf[(k + j) * 2] - x1; + scalar_t y2 = buf[(k + j) * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (d < best) { + best = d; + best_i = k + k2 + j; + } + } + } + } else { + for (int k = 0; k 
< end_ka; k += 4) { +#pragma unroll + for (int j = 0; j < 4; ++j) { + scalar_t x2 = buf[(k + j) * 2] - x1; + scalar_t y2 = buf[(k + j) * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (d < best) { + best = d; + best_i = k + k2 + j; + } + } + } + } + for (int k = end_ka; k < end_k; k++) { + scalar_t x2 = buf[k * 2 + 0] - x1; + scalar_t y2 = buf[k * 2 + 1] - y1; + scalar_t d = x2 * x2 + y2 * y2; + if (k == 0 || d < best) { + best = d; + best_i = k + k2; + } + } + if (k2 == 0 || result[(i * n + j)] > best) { + result[(i * n + j)] = best; + result_i[(i * n + j)] = best_i; + } + } + __syncthreads(); + } + } +} + +template <typename scalar_t> +__global__ void chamfer_distance_backward_cuda_kernel( + int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2, + const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1, + scalar_t* grad_xyz2) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) { + scalar_t x1 = xyz1[(i * n + j) * 2 + 0]; + scalar_t y1 = xyz1[(i * n + j) * 2 + 1]; + int j2 = idx1[i * n + j]; + scalar_t x2 = xyz2[(i * m + j2) * 2 + 0]; + scalar_t y2 = xyz2[(i * m + j2) * 2 + 1]; + scalar_t g = grad_dist1[i * n + j] * 2; + atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2)); + atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2)); + atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2))); + atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2))); + } + } +} +""" + +def chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2): + src = f""" + const int batch_size = {xyz1.size(0)}; + const int n = {xyz1.size(1)}; + const int m = {xyz2.size(1)}; + chamfer_distance_forward_cuda_kernel<float><<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK>>>( + batch_size, n, in0_p, m, + in1_p, out0_p, out2_p); + chamfer_distance_forward_cuda_kernel<float><<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK>>>( + batch_size, m, in1_p, n, + in0_p, out1_p, out3_p); + """ + return jt.code( + outputs=[dist1, dist2, idx1, idx2], + inputs=[xyz1, xyz2], + cuda_header=HEADER, + cuda_src=src) + +def 
chamfer_distance_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2, grad_xyz1, grad_xyz2): + src=f""" + const int batch_size = {xyz1.size(0)}; + const int n = {xyz1.size(1)}; + const int m = {xyz2.size(1)}; + chamfer_distance_backward_cuda_kernel<float><<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK>>>( + batch_size, n, in0_p, m, + in1_p, in4_p, + in2_p, out0_p, + out1_p); + + chamfer_distance_backward_cuda_kernel<float><<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK>>>( + batch_size, m, in1_p, n, + in0_p, in5_p, + in3_p, out1_p, + out0_p); + """ + return jt.code( + outputs=[grad_xyz1, grad_xyz2], + inputs=[xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2], + cuda_header=HEADER, + cuda_src=src) + + +class ChamferDistanceFunction(Function): + """This is an implementation of the 2D Chamfer Distance. + It has been used in the paper `Oriented RepPoints for Aerial Object + Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>`_`. + """ + + def execute(self, xyz1, xyz2): + """ + Args: + xyz1 (Tensor): Point set with shape (B, N, 2). + xyz2 (Tensor): Point set with shape (B, N, 2). + Returns: + Sequence[Tensor]: + - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with + shape (B, N). + - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with + shape (B, N). + - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2) + with shape (B, N), which be used in compute gradient. + - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1) + with shape (B, N), which be used in compute gradient. + """ + batch_size, n, _ = xyz1.size() + _, m, _ = xyz2.size() + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + + dist1 = jt.zeros(batch_size, n) + dist2 = jt.zeros(batch_size, m) + idx1 = jt.zeros((batch_size, n), dtype=jt.int32) + idx2 = jt.zeros((batch_size, m), dtype=jt.int32) + + chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + self.save_for_backward = xyz1, xyz2, idx1, idx2 + return dist1, dist2, idx1, idx2 + + def grad(self, + grad_dist1, + grad_dist2, + grad_idx1=None, + grad_idx2=None): + """ + Args: + grad_dist1 (Tensor): Gradient of chamfer distance + (xyz1 to xyz2) with shape (B, N). 
+ grad_dist2 (Tensor): Gradient of chamfer distance + (xyz2 to xyz1) with shape (B, N). + Returns: + Tuple[Tensor, Tensor]: + - grad_xyz1 (Tensor): Gradient of the point set with shape \ + (B, N, 2). + - grad_xyz2 (Tensor):Gradient of the point set with shape \ + (B, N, 2). + """ + xyz1, xyz2, idx1, idx2 = self.save_for_backward + grad_dist1 = grad_dist1.contiguous() + grad_dist2 = grad_dist2.contiguous() + grad_xyz1 = jt.zeros(xyz1.size()) + grad_xyz2 = jt.zeros(xyz2.size()) + + chamfer_distance_backward(xyz1, xyz2, idx1, idx2, + grad_dist1, grad_dist2, grad_xyz1, + grad_xyz2) + return grad_xyz1, grad_xyz2 + + +chamfer_distance = ChamferDistanceFunction.apply + +def test(): + pointset1 = jt.array( + [[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]], + [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]], + [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]]) + + pointset2 = jt.array( + [[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]], + [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]], + [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]]) + + expected_dist1 = jt.array( + [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900], + [0.5200, 0.6500, 0.4900, 0.3600]]) + expected_dist2 = jt.array( + [[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900], + [0.7200, 0.8500, 0.4900, 0.3600]]) + + dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2) + np.testing.assert_allclose(dist1.numpy(),expected_dist1.numpy(),rtol=1e-2) + np.testing.assert_allclose(dist2.numpy(),expected_dist2.numpy(),rtol=1e-2) + +if __name__ == "__main__": + jt.flags.use_cuda=1 + test()