Skip to content

Commit 2b0eeb6

Browse files
committed
Fix SSD inference bug
1 parent f992a1c commit 2b0eeb6

File tree

10 files changed

+80
-50
lines changed

10 files changed

+80
-50
lines changed

configs/ssd/_base_/ssd_mobilenet_reader.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ TrainReader:
66
- Decode: {}
77
- RandomDistort: {brightness: [0.5, 1.125, 0.875], random_apply: False}
88
- RandomExpand: {fill_value: [127.5, 127.5, 127.5]}
9-
- RandomCrop: {allow_no_crop: Fasle}
9+
- RandomCrop: {allow_no_crop: False}
1010
- RandomFlip: {}
1111
- Resize: {target_size: [300, 300], keep_ratio: False, interp: 1}
1212
- NormalizeBox: {}

configs/ssd/_base_/ssd_vgg16_300.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
architecture: SSD
22
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/VGG16_caffe_pretrained.pdparams
33

4-
# Model Achitecture
4+
# Model Architecture
55
SSD:
66
# model feat info flow
77
backbone: VGG

ppdettorch/modeling/architectures/ssd.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,12 @@ def _forward(self):
7575
self.inputs['gt_class'])
7676
else:
7777
preds, anchors = self.ssd_head(body_feats, self.inputs['image'])
78-
bbox, bbox_num = self.post_process(preds, anchors,
79-
self.inputs['im_shape'],
80-
self.inputs['scale_factor'])
78+
79+
bbox, bbox_num, before_nms_indexes = self.post_process(preds, anchors,
80+
self.inputs['im_shape'],
81+
self.inputs['scale_factor'])
82+
83+
return bbox, bbox_num
8184
return bbox, bbox_num
8285

8386
def get_loss(self, ):

ppdettorch/modeling/backbones/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# @Author :sl
66
# @Date :2022/11/1 14:45
77

8-
# from . import vgg
8+
from . import vgg
99
from . import resnet
1010
from . import darknet
1111
from . import mobilenet_v1
@@ -32,7 +32,7 @@
3232
# from . import vision_transformer
3333
# from . import mobileone
3434

35-
# from .vgg import *
35+
from .vgg import *
3636
from .resnet import *
3737
from .darknet import *
3838
from .mobilenet_v1 import *

ppdettorch/modeling/backbones/mobilenet_v3.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,15 @@ def __init__(self,
202202
norm_type='bn',
203203
norm_decay=0.,
204204
freeze_norm=False,
205-
name=None):
205+
name=None,
206+
padding=3):
206207
super(ExtraBlockDW, self).__init__()
207208
self.pointwise_conv = ConvBNLayer(
208209
in_c=in_c,
209210
out_c=ch_1,
210211
filter_size=1,
211212
stride=1,
212-
padding='SAME',
213+
padding=0,
213214
act='relu6',
214215
lr_mult=lr_mult,
215216
conv_decay=conv_decay,
@@ -222,7 +223,7 @@ def __init__(self,
222223
out_c=ch_2,
223224
filter_size=3,
224225
stride=stride,
225-
padding='SAME',
226+
padding=padding, # TODO: same padding
226227
num_groups=int(ch_1),
227228
act='relu6',
228229
lr_mult=lr_mult,
@@ -236,7 +237,7 @@ def __init__(self,
236237
out_c=ch_2,
237238
filter_size=1,
238239
stride=1,
239-
padding='SAME',
240+
padding=0,
240241
act='relu6',
241242
lr_mult=lr_mult,
242243
conv_decay=conv_decay,
@@ -412,7 +413,8 @@ def __init__(
412413
norm_type=norm_type,
413414
norm_decay=norm_decay,
414415
freeze_norm=freeze_norm,
415-
name='conv' + str(i + 2))
416+
name='conv' + str(i + 2),
417+
padding=3) # TODO: calc same padding
416418
self.add_module("conv" + str(i + 2), module=conv_extra)
417419
self.extra_block_list.append(conv_extra)
418420
i += 1

ppdettorch/modeling/backbones/vgg.py

+18-25
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66

7-
87
from torch.nn import Conv2d, MaxPool2d
98
from ppdettorch.core.workspace import register, serializable
109
from ..shape_spec import ShapeSpec
@@ -34,21 +33,18 @@ def __init__(self,
3433
padding=1)
3534
self.conv_out_list = []
3635
for i in range(1, groups):
37-
conv_out = self.add_sublayer(
38-
'conv{}'.format(i),
39-
Conv2d(
40-
in_channels=out_channels,
41-
out_channels=out_channels,
42-
kernel_size=3,
43-
stride=1,
44-
padding=1))
36+
conv_out = Conv2d(in_channels=out_channels,
37+
out_channels=out_channels,
38+
kernel_size=3,
39+
stride=1,
40+
padding=1)
41+
self.add_module('conv{}'.format(i), conv_out)
4542
self.conv_out_list.append(conv_out)
4643

47-
self.pool = MaxPool2d(
48-
kernel_size=pool_size,
49-
stride=pool_stride,
50-
padding=pool_padding,
51-
ceil_mode=True)
44+
self.pool = MaxPool2d(kernel_size=pool_size,
45+
stride=pool_stride,
46+
padding=pool_padding,
47+
ceil_mode=True)
5248

5349
def forward(self, inputs):
5450
out = self.conv0(inputs)
@@ -95,12 +91,10 @@ def forward(self, inputs):
9591
class L2NormScale(nn.Module):
9692
def __init__(self, num_channels, scale=1.0):
9793
super(L2NormScale, self).__init__()
98-
self.scale = self.create_parameter(
99-
attr=ParamAttr(initializer=paddle.nn.initializer.Constant(scale)),
100-
shape=[num_channels])
94+
self.scale = nn.Parameter(torch.ones([num_channels]), requires_grad=False)
10195

10296
def forward(self, inputs):
103-
out = F.normalize(inputs, axis=1, epsilon=1e-10)
97+
out = F.normalize(inputs, dim=1, eps=1e-10)
10498
# out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(
10599
# out) * out
106100
out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3) * out
@@ -119,7 +113,7 @@ def __init__(self,
119113
super(VGG, self).__init__()
120114

121115
assert depth in [16, 19], \
122-
"depth as 16/19 supported currently, but got {}".format(depth)
116+
"depth as 16/19 supported currently, but got {}".format(depth)
123117
self.depth = depth
124118
self.groups = VGG_cfg[depth]
125119
self.normalizations = normalizations
@@ -159,19 +153,18 @@ def __init__(self,
159153
last_channels = 1024
160154
for i, v in enumerate(self.extra_block_filters):
161155
assert len(v) == 5, "extra_block_filters size not fix"
162-
extra_conv = self.add_sublayer("conv{}".format(6 + i),
163-
ExtraBlock(last_channels, v[0], v[1],
164-
v[2], v[3], v[4]))
156+
extra_conv = ExtraBlock(last_channels, v[0], v[1],
157+
v[2], v[3], v[4])
158+
self.add_module("conv{}".format(6 + i), extra_conv)
165159
last_channels = v[1]
166160
self.extra_convs.append(extra_conv)
167161
self._out_channels.append(last_channels)
168162

169163
self.norms = []
170164
for i, n in enumerate(self.normalizations):
171165
if n != -1:
172-
norm = self.add_sublayer("norm{}".format(i),
173-
L2NormScale(
174-
self.extra_block_filters[i][1], n))
166+
norm = L2NormScale(self.extra_block_filters[i][1], n)
167+
self.add_module("norm{}".format(i), norm)
175168
else:
176169
norm = None
177170
self.norms.append(norm)

ppdettorch/modeling/bbox_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def ssd_prior_box_np(
885885
Default: None, means [] and will not be used.
886886
aspect_ratios (list|tuple|float, optional): the aspect ratios of generated
887887
prior boxes. Default: [1.0].
888-
variances (list|tuple, optional): the variances to be encoded in prior boxes.
888+
variance (list|tuple, optional): the variances to be encoded in prior boxes.
889889
Default:[0.1, 0.1, 0.2, 0.2].
890890
flip (bool): Whether to flip aspect ratios. Default:False.
891891
clip (bool): Whether to clip out-of-boundary boxes. Default: False.
@@ -941,6 +941,10 @@ def ssd_prior_box_np(
941941
if flip:
942942
real_aspect_ratios.append(1.0 / ar)
943943

944+
if step_w == 0 or step_h == 0:
945+
step_w = image_w / layer_w
946+
step_h = image_h / layer_h
947+
944948
num_priors = len(real_aspect_ratios) * len(min_sizes)
945949
if max_sizes is None:
946950
max_sizes = []

ppdettorch/modeling/layers.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -258,29 +258,29 @@ def __init__(self,
258258
stride=stride,
259259
groups=in_channels,
260260
norm_type=norm_type,
261-
)
261+
)
262262
conv2 = ConvNormLayer(
263263
in_channels,
264264
out_channels,
265265
filter_size=1,
266266
stride=stride,
267267
norm_type=norm_type,
268-
)
268+
)
269269
conv3 = ConvNormLayer(
270270
out_channels,
271271
out_channels,
272272
filter_size=1,
273273
stride=stride,
274274
norm_type=norm_type,
275-
)
275+
)
276276
conv4 = ConvNormLayer(
277277
out_channels,
278278
out_channels,
279279
filter_size=5,
280280
stride=stride,
281281
groups=out_channels,
282282
norm_type=norm_type,
283-
)
283+
)
284284
conv_list = [conv1, conv2, conv3, conv4]
285285
self.lite_conv.add_sublayer('conv1', conv1)
286286
self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
@@ -627,6 +627,7 @@ def __call__(self,
627627

628628
return yolo_boxes, yolo_scores
629629

630+
630631
class YOLOLayer(nn.Module):
631632
"""Detection layer"""
632633

@@ -720,7 +721,7 @@ def __call__(self,
720721
output_boxes *= im_shape
721722
else:
722723
output_boxes[..., -2:] -= 1.0
723-
output_scores = F.softmax(torch.concat(scores, dim=1)).permute(0, 2, 1)
724+
output_scores = F.softmax(torch.concat(scores, dim=1), dim=-1).permute(0, 2, 1)
724725

725726
return output_boxes, output_scores
726727

ppdettorch/modeling/post_process.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __call__(self, head_out, rois, im_shape, scale_factor):
6262
"""
6363
if self.nms is not None:
6464
bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
65-
bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
65+
bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, self.num_classes)
6666

6767
else:
6868
bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
@@ -77,7 +77,10 @@ def __call__(self, head_out, rois, im_shape, scale_factor):
7777
bbox_pred = torch.concat([bbox_pred, fake_bboxes])
7878
bbox_num = bbox_num + 1
7979

80-
return bbox_pred, bbox_num
80+
if self.nms is not None:
81+
return bbox_pred, bbox_num, before_nms_indexes
82+
else:
83+
return bbox_pred, bbox_num
8184

8285
def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
8386
"""

tests/process/infer/run_detection_infer.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,10 @@ def run_picodet_coco(self, config_name=None):
163163
# config_name = f"yolov8_n_500e_coco.yml"
164164

165165
# ssd
166-
config_name = f"ssd_mobilenet_v1_300_120e_voc.yml"
166+
# config_name = f"ssd_mobilenet_v1_300_120e_voc.yml"
167+
config_name = f"ssd_vgg16_300_240e_voc.yml"
168+
# config_name = f"ssdlite_mobilenet_v3_large_320_coco.yml"
169+
# config_name = f"ssdlite_mobilenet_v3_small_320_coco.yml"
167170

168171

169172
# run_arg = DetectionInferUtils.init_args()
@@ -176,6 +179,8 @@ def run_picodet_coco(self, config_name=None):
176179
model_class = "picodet"
177180
elif "ppyoloe" in config_name:
178181
model_class = "ppyoloe"
182+
elif "ssdlite_" in config_name:
183+
model_class = "ssd"
179184
else:
180185
config_name_end_index = FileUtils.get_file_name(config_name).find("_")
181186
model_class = config_name[:config_name_end_index]
@@ -229,14 +234,27 @@ def run_picodet_coco_batch():
229234
# model_class = "yolov6"
230235
# model_class = "yolov7"
231236
# model_class = "rtmdet"
232-
model_class = "yolov8"
237+
# model_class = "yolov8"
238+
model_class = "ssd"
233239

234240
with_application = False
235241
# with_application = True
236242

237243
# do_transform = False
238244
do_transform = True
239245

246+
# Config files to skip during validation, keyed by model class
247+
skip_config_name_dict = {
248+
"yolov3": [
249+
"yolov3_darknet53_original_270e_coco.yml",
250+
"yolov3_mobilenet_v1_roadsign.yml"
251+
],
252+
"ssd": [
253+
"ssd_r34_70e_coco.yml",
254+
"ssdlite_ghostnet_320_coco.yml"
255+
]
256+
}
257+
240258
# base_dir = f"/home/mqq/shenglei/ocr/PaddleDetection/configs/{model_class}"
241259
base_dir = f"{Constants.WORK_DIR}/configs/{model_class}"
242260
if with_application:
@@ -249,14 +267,20 @@ def run_picodet_coco_batch():
249267
end_with=".yml", )
250268

251269
logger.info(f"total: {len(file_name_list)}")
252-
skip = 0
270+
271+
skip_config_name_list = skip_config_name_dict.get(model_class, [])
272+
skip = 1
253273
detection_runner = DetectionRunInfer()
254274

255275
for index, file_name in enumerate(file_name_list):
256276
if index < skip:
257277
logger.info(f"跳过已经执行的:{index} - {file_name}")
258278
continue
259279

280+
if f"{FileUtils.get_file_name(file_name)}.yml" in skip_config_name_list:
281+
logger.info(f"跳过无需测试的:{index} - {file_name}")
282+
continue
283+
260284
if "_xpu" in file_name:
261285
continue
262286
if "ppyoloe_crn_m_80e_pcb" in file_name:
@@ -269,5 +293,5 @@ def run_picodet_coco_batch():
269293

270294

271295
if __name__ == '__main__':
272-
demo_run_detection_infer()
273-
# run_picodet_coco_batch()
296+
# demo_run_detection_infer()
297+
run_picodet_coco_batch()

0 commit comments

Comments
 (0)