This repository was archived by the owner on Apr 17, 2023. It is now read-only.

Commit af2d508

hhaAndroid authored and bognabylicka committed
Add YOLOX config (open-mmlab#5808)
* Add YOLOX config
* update
* fix error
* fix lr error
* fix tiny config error and foreground_mask warning
* fix dp train error
* add comment
* support browse_dataset
* add comment
* fix __repr__
* Switch to synchronizing norm interval.
* Add README and metafile
* update README
* update doc
* rename
* revert
* update

(cherry picked from commit 2bdb167)
1 parent 293374e commit af2d508

10 files changed, 372 additions and 0 deletions

configs/yolox/README.md

+25
@@ -0,0 +1,25 @@
+# YOLOX: Exceeding YOLO Series in 2021
+
+## Introduction
+
+<!-- [ALGORITHM] -->
+
+```latex
+@article{yolox2021,
+  title={{YOLOX}: Exceeding YOLO Series in 2021},
+  author={Ge, Zheng and Liu, Songtao and Wang, Feng and Li, Zeming and Sun, Jian},
+  journal={arXiv preprint arXiv:2107.08430},
+  year={2021}
+}
+```
+
+## Results and Models
+
+| Backbone | size | Mem (GB) | box AP | Config | Download |
+|:---------:|:-------:|:-------:|:-------:|:--------:|:------:|
+| YOLOX-Tiny | 416 | 3.6 | 31.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_tiny_8x8_300e_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250.log.json) |
+
+**Note**:
+
+1. The test score threshold is 0.001.
+2. We find that the performance is unstable and may fluctuate by about 0.7 mAP. We will continue to investigate and improve it.
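
The README note gives a test score threshold of 0.001, while the base config added in this commit sets `score_thr=0.01` and comments that 0.01 is the validation value. The commit does not show how the test-time value is applied, so the following is only a minimal sketch of overriding one key while keeping the rest of `test_cfg`. `merge_cfg` and `final_test_cfg` are hypothetical names used for illustration, not part of the library.

```python
import copy

# Values copied from the base YOLOX config shown in this commit.
base_test_cfg = dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))


def merge_cfg(base, override):
    """Recursively merge `override` into a copy of `base` (hypothetical helper)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical test-time override matching the README note.
final_test_cfg = merge_cfg(base_test_cfg, dict(score_thr=0.001))
print(final_test_cfg)
```

Only `score_thr` changes; the NMS settings carry over from the base dictionary, and the base dictionary itself is left unmodified.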

configs/yolox/metafile.yml

+28
@@ -0,0 +1,28 @@
+Collections:
+  - Name: YOLOX
+    Metadata:
+      Training Data: COCO
+      Training Techniques:
+        - SGD with Nesterov
+        - Weight Decay
+        - Cosine Annealing Lr Updater
+      Training Resources: 8x TITANXp GPUs
+      Architecture:
+        - CSPDarkNet
+        - PAFPN
+    Paper: https://arxiv.org/abs/2107.08430
+    README: configs/yolox/README.md
+
+Models:
+  - Name: yolox_tiny_8x8_300e_coco
+    In Collection: YOLOX
+    Config: configs/yolox/yolox_tiny_8x8_300e_coco.py
+    Metadata:
+      Training Memory (GB): 3.6
+      Epochs: 300
+    Results:
+      - Task: Object Detection
+        Dataset: COCO
+        Metrics:
+          box AP: 31.6
+    Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth
+8
@@ -0,0 +1,8 @@
+_base_ = './yolox_s_8x8_300e_coco.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=1.0, widen_factor=1.0),
+    neck=dict(
+        in_channels=[256, 512, 1024], out_channels=256, num_csp_blocks=3),
+    bbox_head=dict(in_channels=256, feat_channels=256))
+8
@@ -0,0 +1,8 @@
+_base_ = './yolox_s_8x8_300e_coco.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=0.67, widen_factor=0.75),
+    neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
+    bbox_head=dict(in_channels=192, feat_channels=192),
+)
+11
@@ -0,0 +1,11 @@
+_base_ = './yolox_tiny_8x8_300e_coco.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=0.33, widen_factor=0.25, use_depthwise=True),
+    neck=dict(
+        in_channels=[64, 128, 256],
+        out_channels=64,
+        num_csp_blocks=1,
+        use_depthwise=True),
+    bbox_head=dict(in_channels=64, feat_channels=64, use_depthwise=True))
+143
@@ -0,0 +1,143 @@
+_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']
+
+# model settings
+model = dict(
+    type='YOLOX',
+    backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+    neck=dict(
+        type='YOLOXPAFPN',
+        in_channels=[128, 256, 512],
+        out_channels=128,
+        num_csp_blocks=1),
+    bbox_head=dict(
+        type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
+    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+    # In order to align the source code, the threshold of the val phase is
+    # 0.01, and the threshold of the test phase is 0.001.
+    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+# dataset settings
+data_root = 'data/coco/'
+dataset_type = 'CocoDataset'
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (640, 640)
+
+train_pipeline = [
+    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
+    dict(
+        type='RandomAffine',
+        scaling_ratio_range=(0.1, 2),
+        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
+    dict(
+        type='MixUp',
+        img_scale=img_scale,
+        ratio_range=(0.8, 1.6),
+        pad_val=114.0),
+    dict(
+        type='PhotoMetricDistortion',
+        brightness_delta=32,
+        contrast_range=(0.5, 1.5),
+        saturation_range=(0.5, 1.5),
+        hue_delta=18),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Resize', keep_ratio=True),
+    dict(type='Pad', pad_to_square=True, pad_val=114.0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+
+train_dataset = dict(
+    type='MultiImageMixDataset',
+    dataset=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=[
+            dict(type='LoadImageFromFile', to_float32=True),
+            dict(type='LoadAnnotations', with_bbox=True)
+        ],
+        filter_empty_gt=False,
+    ),
+    pipeline=train_pipeline,
+    dynamic_scale=img_scale)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Pad', size=img_scale, pad_val=114.0),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='DefaultFormatBundle'),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=2,
+    train=train_dataset,
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+
+# optimizer
+# default 8 gpu
+optimizer = dict(
+    type='SGD',
+    lr=0.01,
+    momentum=0.9,
+    weight_decay=5e-4,
+    nesterov=True,
+    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
+optimizer_config = dict(grad_clip=None)
+
+# learning policy
+lr_config = dict(
+    _delete_=True,
+    policy='YOLOX',
+    warmup='exp',
+    by_epoch=False,
+    warmup_by_epoch=True,
+    warmup_ratio=1,
+    warmup_iters=5,  # 5 epoch
+    num_last_epochs=15,
+    min_lr_ratio=0.05)
+runner = dict(type='EpochBasedRunner', max_epochs=300)
+
+resume_from = None
+interval = 10
+
+custom_hooks = [
+    dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
+    dict(
+        type='SyncRandomSizeHook',
+        ratio_range=(14, 26),
+        img_scale=img_scale,
+        interval=interval,
+        priority=48),
+    dict(
+        type='SyncNormHook',
+        num_last_epochs=15,
+        interval=interval,
+        priority=48),
+    dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
+]
+checkpoint_config = dict(interval=interval)
+evaluation = dict(interval=interval, metric='bbox')
+log_config = dict(interval=50)
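
The `lr_config` above names a `'YOLOX'` policy with a 5-epoch `'exp'` warmup, a `min_lr_ratio` of 0.05 and 15 final epochs treated separately. The commit does not include the scheduler implementation, so the sketch below is an assumption about the schedule's shape: a quadratic warmup ramp, cosine decay down to `lr * min_lr_ratio`, then a constant floor for the last epochs. The function `yolox_lr` and the `iters_per_epoch` argument are hypothetical; the numeric defaults are taken from the config.

```python
import math


def yolox_lr(epoch, it_in_epoch, iters_per_epoch, base_lr=0.01,
             max_epochs=300, warmup_epochs=5, num_last_epochs=15,
             min_lr_ratio=0.05, warmup_ratio=1.0):
    """Approximate learning rate at a given iteration (assumed schedule shape)."""
    cur_iter = epoch * iters_per_epoch + it_in_epoch
    warmup_iters = warmup_epochs * iters_per_epoch
    min_lr = base_lr * min_lr_ratio

    if cur_iter < warmup_iters:
        # Assumed 'exp' warmup: quadratic ramp from near zero up to base_lr.
        return base_lr * warmup_ratio * ((cur_iter + 1) / warmup_iters) ** 2

    if epoch >= max_epochs - num_last_epochs:
        # Final epochs: hold the learning rate at its floor.
        return min_lr

    # Cosine decay between the end of warmup and the start of the final epochs.
    decay_iters = (max_epochs - num_last_epochs) * iters_per_epoch - warmup_iters
    progress = (cur_iter - warmup_iters) / decay_iters
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Print the rate at the start of a few epochs, assuming 100 iterations per epoch.
for epoch in (0, 2, 5, 145, 290):
    print(epoch, yolox_lr(epoch, 0, iters_per_epoch=100))
```

Under these assumptions the rate peaks at 0.01 once warmup ends and settles at 0.0005 for the last 15 epochs, which is when `YOLOXModeSwitchHook` and `SyncNormHook` are also configured with `num_last_epochs=15`.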
+79
@@ -0,0 +1,79 @@
+_base_ = './yolox_s_8x8_300e_coco.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=0.33, widen_factor=0.375),
+    neck=dict(in_channels=[96, 192, 384], out_channels=96),
+    bbox_head=dict(in_channels=96, feat_channels=96))
+
+# dataset settings
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (640, 640)
+
+train_pipeline = [
+    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
+    dict(
+        type='RandomAffine',
+        scaling_ratio_range=(0.5, 1.5),
+        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
+    dict(
+        type='PhotoMetricDistortion',
+        brightness_delta=32,
+        contrast_range=(0.5, 1.5),
+        saturation_range=(0.5, 1.5),
+        hue_delta=18),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Resize', keep_ratio=True),
+    dict(type='Pad', pad_to_square=True, pad_val=114.0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(416, 416),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Pad', size=(416, 416), pad_val=114.0),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='DefaultFormatBundle'),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+train_dataset = dict(pipeline=train_pipeline)
+
+data = dict(
+    train=train_dataset,
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
+
+resume_from = None
+interval = 10
+
+# Execute in the order of insertion when the priority is the same.
+# The smaller the value, the higher the priority
+custom_hooks = [
+    dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
+    dict(
+        type='SyncRandomSizeHook',
+        ratio_range=(10, 20),
+        img_scale=img_scale,
+        interval=interval,
+        priority=48),
+    dict(
+        type='SyncNormHook',
+        num_last_epochs=15,
+        interval=interval,
+        priority=48),
+    dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
+]
+checkpoint_config = dict(interval=interval)
+evaluation = dict(interval=interval, metric='bbox')
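
The comment above `custom_hooks` states two rules: a smaller priority value runs first, and hooks with equal priority run in insertion order. A stable sort on the priority value reproduces exactly that behaviour. The sketch below illustrates the rule with the hook names from this config; `execution_order` is a hypothetical helper and not how the runner registers hooks internally.

```python
# Hook names and priorities copied from the config; other keys omitted.
custom_hooks = [
    dict(type='YOLOXModeSwitchHook', priority=48),
    dict(type='SyncRandomSizeHook', priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(type='ExpMomentumEMAHook', priority=49),
]


def execution_order(hooks):
    """Return hook names in run order: ascending priority, ties by insertion."""
    # sorted() is stable, so hooks with equal priority keep their list order.
    return [hook['type'] for hook in sorted(hooks, key=lambda h: h['priority'])]


print(execution_order(custom_hooks))
```

With these priorities the EMA hook runs after the three hooks at priority 48, regardless of where it appears in the list.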
+8
@@ -0,0 +1,8 @@
+_base_ = './yolox_s_8x8_300e_coco.py'
+
+# model settings
+model = dict(
+    backbone=dict(deepen_factor=1.33, widen_factor=1.25),
+    neck=dict(
+        in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4),
+    bbox_head=dict(in_channels=320, feat_channels=320))
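
The model variants in this commit differ only in `deepen_factor`, `widen_factor` and the channel and block counts that follow from them. The numbers in the configs are consistent with scaling the factor-1.0 widths `[256, 512, 1024]` by `widen_factor` and the 3 CSP blocks by `deepen_factor`. The sketch below checks that relationship; `scaled_neck_cfg` is a hypothetical helper, and whether the library derives the values this way is an assumption.

```python
# Widths and block count from the config with deepen_factor=1.0, widen_factor=1.0.
BASE_CHANNELS = (256, 512, 1024)
BASE_CSP_BLOCKS = 3


def scaled_neck_cfg(deepen_factor, widen_factor):
    """Derive neck settings from the two scaling factors (assumed rule)."""
    in_channels = [int(c * widen_factor) for c in BASE_CHANNELS]
    return dict(
        in_channels=in_channels,
        out_channels=in_channels[0],
        num_csp_blocks=max(round(BASE_CSP_BLOCKS * deepen_factor), 1))


# (deepen_factor, widen_factor) pairs that appear in the configs of this commit.
for factors in [(0.33, 0.25), (0.33, 0.375), (0.33, 0.5),
                (0.67, 0.75), (1.0, 1.0), (1.33, 1.25)]:
    print(factors, scaled_neck_cfg(*factors))
```

When adding a custom variant, the same rule gives a quick consistency check between the backbone factors and the `neck` and `bbox_head` channel settings.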

docs/model_zoo.md

+4
@@ -186,6 +186,10 @@ Please refer to [ResNeSt](https://github.com/open-mmlab/mmdetection/blob/master/
 
 Please refer to [DETR](https://github.com/open-mmlab/mmdetection/blob/master/configs/detr) for details.
 
+### YOLOX
+
+Please refer to [YOLOX](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).

model-index.yml

+58
@@ -0,0 +1,58 @@
+Import:
+  - configs/atss/metafile.yml
+  - configs/autoassign/metafile.yml
+  - configs/cascade_rcnn/metafile.yml
+  - configs/centernet/metafile.yml
+  - configs/centripetalnet/metafile.yml
+  - configs/cornernet/metafile.yml
+  - configs/dcn/metafile.yml
+  - configs/deformable_detr/metafile.yml
+  - configs/detectors/metafile.yml
+  - configs/detr/metafile.yml
+  - configs/double_heads/metafile.yml
+  - configs/dynamic_rcnn/metafile.yml
+  - configs/empirical_attention/metafile.yml
+  - configs/faster_rcnn/metafile.yml
+  - configs/fcos/metafile.yml
+  - configs/foveabox/metafile.yml
+  - configs/fp16/metafile.yml
+  - configs/fpg/metafile.yml
+  - configs/free_anchor/metafile.yml
+  - configs/fsaf/metafile.yml
+  - configs/gcnet/metafile.yml
+  - configs/gfl/metafile.yml
+  - configs/ghm/metafile.yml
+  - configs/gn/metafile.yml
+  - configs/gn+ws/metafile.yml
+  - configs/grid_rcnn/metafile.yml
+  - configs/groie/metafile.yml
+  - configs/guided_anchoring/metafile.yml
+  - configs/hrnet/metafile.yml
+  - configs/htc/metafile.yml
+  - configs/instaboost/metafile.yml
+  - configs/ld/metafile.yml
+  - configs/libra_rcnn/metafile.yml
+  - configs/mask_rcnn/metafile.yml
+  - configs/ms_rcnn/metafile.yml
+  - configs/nas_fcos/metafile.yml
+  - configs/nas_fpn/metafile.yml
+  - configs/paa/metafile.yml
+  - configs/pafpn/metafile.yml
+  - configs/pisa/metafile.yml
+  - configs/point_rend/metafile.yml
+  - configs/regnet/metafile.yml
+  - configs/reppoints/metafile.yml
+  - configs/res2net/metafile.yml
+  - configs/resnest/metafile.yml
+  - configs/retinanet/metafile.yml
+  - configs/sabl/metafile.yml
+  - configs/scnet/metafile.yml
+  - configs/scratch/metafile.yml
+  - configs/sparse_rcnn/metafile.yml
+  - configs/ssd/metafile.yml
+  - configs/tridentnet/metafile.yml
+  - configs/vfnet/metafile.yml
+  - configs/yolact/metafile.yml
+  - configs/yolo/metafile.yml
+  - configs/yolof/metafile.yml
+  - configs/yolox/metafile.yml
