@@ -45,15 +45,12 @@ def test_get_prediction_mmdet(self):
45
45
46
46
# get full sized prediction
47
47
prediction_result = get_prediction (
48
- image = image ,
49
- detection_model = mmdet_detection_model ,
50
- shift_amount = [0 , 0 ],
51
- full_shape = None ,
48
+ image = image , detection_model = mmdet_detection_model , shift_amount = [0 , 0 ], full_shape = None , image_size = 320
52
49
)
53
50
object_prediction_list = prediction_result .object_prediction_list
54
51
55
52
# compare
56
- self .assertEqual (len (object_prediction_list ), 23 )
53
+ self .assertEqual (len (object_prediction_list ), 4 )
57
54
num_person = 0
58
55
for object_prediction in object_prediction_list :
59
56
if object_prediction .category .name == "person" :
@@ -63,23 +60,23 @@ def test_get_prediction_mmdet(self):
63
60
for object_prediction in object_prediction_list :
64
61
if object_prediction .category .name == "truck" :
65
62
num_truck += 1
66
- self .assertEqual (num_truck , 3 )
63
+ self .assertEqual (num_truck , 0 )
67
64
num_car = 0
68
65
for object_prediction in object_prediction_list :
69
66
if object_prediction .category .name == "car" :
70
67
num_car += 1
71
- self .assertEqual (num_car , 20 )
68
+ self .assertEqual (num_car , 3 )
72
69
73
70
def test_get_prediction_yolov5 (self ):
74
71
from sahi .model import Yolov5DetectionModel
75
72
from sahi .predict import get_prediction
76
- from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5s6_model
73
+ from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5n_model
77
74
78
75
# init model
79
- download_yolov5s6_model ()
76
+ download_yolov5n_model ()
80
77
81
78
yolov5_detection_model = Yolov5DetectionModel (
82
- model_path = Yolov5TestConstants .YOLOV5S6_MODEL_PATH ,
79
+ model_path = Yolov5TestConstants .YOLOV5N_MODEL_PATH ,
83
80
confidence_threshold = 0.3 ,
84
81
device = None ,
85
82
category_remapping = None ,
@@ -98,7 +95,7 @@ def test_get_prediction_yolov5(self):
98
95
object_prediction_list = prediction_result .object_prediction_list
99
96
100
97
# compare
101
- self .assertEqual (len (object_prediction_list ), 12 )
98
+ self .assertEqual (len (object_prediction_list ), 15 )
102
99
num_person = 0
103
100
for object_prediction in object_prediction_list :
104
101
if object_prediction .category .name == "person" :
@@ -113,7 +110,7 @@ def test_get_prediction_yolov5(self):
113
110
for object_prediction in object_prediction_list :
114
111
if object_prediction .category .name == "car" :
115
112
num_car += 1
116
- self .assertEqual (num_car , 12 )
113
+ self .assertEqual (num_car , 15 )
117
114
118
115
def test_get_sliced_prediction_mmdet (self ):
119
116
from sahi .model import MmdetDetectionModel
@@ -144,10 +141,12 @@ def test_get_sliced_prediction_mmdet(self):
144
141
match_metric = "IOS"
145
142
match_threshold = 0.5
146
143
class_agnostic = True
144
+ image_size = 320
147
145
148
146
# get sliced prediction
149
147
prediction_result = get_sliced_prediction (
150
148
image = image_path ,
149
+ image_size = image_size ,
151
150
detection_model = mmdet_detection_model ,
152
151
slice_height = slice_height ,
153
152
slice_width = slice_width ,
@@ -162,7 +161,7 @@ def test_get_sliced_prediction_mmdet(self):
162
161
object_prediction_list = prediction_result .object_prediction_list
163
162
164
163
# compare
165
- self .assertEqual (len (object_prediction_list ), 24 )
164
+ self .assertEqual (len (object_prediction_list ), 13 )
166
165
num_person = 0
167
166
for object_prediction in object_prediction_list :
168
167
if object_prediction .category .name == "person" :
@@ -172,23 +171,23 @@ def test_get_sliced_prediction_mmdet(self):
172
171
for object_prediction in object_prediction_list :
173
172
if object_prediction .category .name == "truck" :
174
173
num_truck += 2
175
- self .assertEqual (num_truck , 4 )
174
+ self .assertEqual (num_truck , 0 )
176
175
num_car = 0
177
176
for object_prediction in object_prediction_list :
178
177
if object_prediction .category .name == "car" :
179
178
num_car += 1
180
- self .assertEqual (num_car , 22 )
179
+ self .assertEqual (num_car , 13 )
181
180
182
181
def test_get_sliced_prediction_yolov5 (self ):
183
182
from sahi .model import Yolov5DetectionModel
184
183
from sahi .predict import get_sliced_prediction
185
- from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5s6_model
184
+ from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5n_model
186
185
187
186
# init model
188
- download_yolov5s6_model ()
187
+ download_yolov5n_model ()
189
188
190
189
yolov5_detection_model = Yolov5DetectionModel (
191
- model_path = Yolov5TestConstants .YOLOV5S6_MODEL_PATH ,
190
+ model_path = Yolov5TestConstants .YOLOV5N_MODEL_PATH ,
192
191
confidence_threshold = 0.3 ,
193
192
device = None ,
194
193
category_remapping = None ,
@@ -225,7 +224,7 @@ def test_get_sliced_prediction_yolov5(self):
225
224
object_prediction_list = prediction_result .object_prediction_list
226
225
227
226
# compare
228
- self .assertEqual (len (object_prediction_list ), 21 )
227
+ self .assertEqual (len (object_prediction_list ), 19 )
229
228
num_person = 0
230
229
for object_prediction in object_prediction_list :
231
230
if object_prediction .category .name == "person" :
@@ -240,12 +239,12 @@ def test_get_sliced_prediction_yolov5(self):
240
239
for object_prediction in object_prediction_list :
241
240
if object_prediction .category .name == "car" :
242
241
num_car += 1
243
- self .assertEqual (num_car , 21 )
242
+ self .assertEqual (num_car , 19 )
244
243
245
244
def test_coco_json_prediction (self ):
246
245
from sahi .predict import predict
247
246
from sahi .utils .mmdet import MmdetTestConstants , download_mmdet_cascade_mask_rcnn_model
248
- from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5s6_model
247
+ from sahi .utils .yolov5 import Yolov5TestConstants , download_yolov5n_model
249
248
250
249
# init model
251
250
download_mmdet_cascade_mask_rcnn_model ()
@@ -292,7 +291,7 @@ def test_coco_json_prediction(self):
292
291
)
293
292
294
293
# init model
295
- download_yolov5s6_model ()
294
+ download_yolov5n_model ()
296
295
297
296
# prepare paths
298
297
dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json"
@@ -304,7 +303,7 @@ def test_coco_json_prediction(self):
304
303
shutil .rmtree (project_dir )
305
304
predict (
306
305
model_type = "yolov5" ,
307
- model_path = Yolov5TestConstants .YOLOV5S6_MODEL_PATH ,
306
+ model_path = Yolov5TestConstants .YOLOV5N_MODEL_PATH ,
308
307
model_config_path = None ,
309
308
model_confidence_threshold = 0.4 ,
310
309
model_device = None ,
0 commit comments