Skip to content

Commit 55b2641

Browse files
authored
improve test durations (#252)
* improve test durations * update tests * fix test
1 parent 842c39f commit 55b2641

File tree

2 files changed

+36
-37
lines changed

2 files changed

+36
-37
lines changed

tests/test_predict.py

+22-23
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,12 @@ def test_get_prediction_mmdet(self):
4545

4646
# get full sized prediction
4747
prediction_result = get_prediction(
48-
image=image,
49-
detection_model=mmdet_detection_model,
50-
shift_amount=[0, 0],
51-
full_shape=None,
48+
image=image, detection_model=mmdet_detection_model, shift_amount=[0, 0], full_shape=None, image_size=320
5249
)
5350
object_prediction_list = prediction_result.object_prediction_list
5451

5552
# compare
56-
self.assertEqual(len(object_prediction_list), 23)
53+
self.assertEqual(len(object_prediction_list), 4)
5754
num_person = 0
5855
for object_prediction in object_prediction_list:
5956
if object_prediction.category.name == "person":
@@ -63,23 +60,23 @@ def test_get_prediction_mmdet(self):
6360
for object_prediction in object_prediction_list:
6461
if object_prediction.category.name == "truck":
6562
num_truck += 1
66-
self.assertEqual(num_truck, 3)
63+
self.assertEqual(num_truck, 0)
6764
num_car = 0
6865
for object_prediction in object_prediction_list:
6966
if object_prediction.category.name == "car":
7067
num_car += 1
71-
self.assertEqual(num_car, 20)
68+
self.assertEqual(num_car, 3)
7269

7370
def test_get_prediction_yolov5(self):
7471
from sahi.model import Yolov5DetectionModel
7572
from sahi.predict import get_prediction
76-
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
73+
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model
7774

7875
# init model
79-
download_yolov5s6_model()
76+
download_yolov5n_model()
8077

8178
yolov5_detection_model = Yolov5DetectionModel(
82-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
79+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
8380
confidence_threshold=0.3,
8481
device=None,
8582
category_remapping=None,
@@ -98,7 +95,7 @@ def test_get_prediction_yolov5(self):
9895
object_prediction_list = prediction_result.object_prediction_list
9996

10097
# compare
101-
self.assertEqual(len(object_prediction_list), 12)
98+
self.assertEqual(len(object_prediction_list), 15)
10299
num_person = 0
103100
for object_prediction in object_prediction_list:
104101
if object_prediction.category.name == "person":
@@ -113,7 +110,7 @@ def test_get_prediction_yolov5(self):
113110
for object_prediction in object_prediction_list:
114111
if object_prediction.category.name == "car":
115112
num_car += 1
116-
self.assertEqual(num_car, 12)
113+
self.assertEqual(num_car, 15)
117114

118115
def test_get_sliced_prediction_mmdet(self):
119116
from sahi.model import MmdetDetectionModel
@@ -144,10 +141,12 @@ def test_get_sliced_prediction_mmdet(self):
144141
match_metric = "IOS"
145142
match_threshold = 0.5
146143
class_agnostic = True
144+
image_size = 320
147145

148146
# get sliced prediction
149147
prediction_result = get_sliced_prediction(
150148
image=image_path,
149+
image_size=image_size,
151150
detection_model=mmdet_detection_model,
152151
slice_height=slice_height,
153152
slice_width=slice_width,
@@ -162,7 +161,7 @@ def test_get_sliced_prediction_mmdet(self):
162161
object_prediction_list = prediction_result.object_prediction_list
163162

164163
# compare
165-
self.assertEqual(len(object_prediction_list), 24)
164+
self.assertEqual(len(object_prediction_list), 13)
166165
num_person = 0
167166
for object_prediction in object_prediction_list:
168167
if object_prediction.category.name == "person":
@@ -172,23 +171,23 @@ def test_get_sliced_prediction_mmdet(self):
172171
for object_prediction in object_prediction_list:
173172
if object_prediction.category.name == "truck":
174173
num_truck += 2
175-
self.assertEqual(num_truck, 4)
174+
self.assertEqual(num_truck, 0)
176175
num_car = 0
177176
for object_prediction in object_prediction_list:
178177
if object_prediction.category.name == "car":
179178
num_car += 1
180-
self.assertEqual(num_car, 22)
179+
self.assertEqual(num_car, 13)
181180

182181
def test_get_sliced_prediction_yolov5(self):
183182
from sahi.model import Yolov5DetectionModel
184183
from sahi.predict import get_sliced_prediction
185-
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
184+
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model
186185

187186
# init model
188-
download_yolov5s6_model()
187+
download_yolov5n_model()
189188

190189
yolov5_detection_model = Yolov5DetectionModel(
191-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
190+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
192191
confidence_threshold=0.3,
193192
device=None,
194193
category_remapping=None,
@@ -225,7 +224,7 @@ def test_get_sliced_prediction_yolov5(self):
225224
object_prediction_list = prediction_result.object_prediction_list
226225

227226
# compare
228-
self.assertEqual(len(object_prediction_list), 21)
227+
self.assertEqual(len(object_prediction_list), 19)
229228
num_person = 0
230229
for object_prediction in object_prediction_list:
231230
if object_prediction.category.name == "person":
@@ -240,12 +239,12 @@ def test_get_sliced_prediction_yolov5(self):
240239
for object_prediction in object_prediction_list:
241240
if object_prediction.category.name == "car":
242241
num_car += 1
243-
self.assertEqual(num_car, 21)
242+
self.assertEqual(num_car, 19)
244243

245244
def test_coco_json_prediction(self):
246245
from sahi.predict import predict
247246
from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model
248-
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
247+
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model
249248

250249
# init model
251250
download_mmdet_cascade_mask_rcnn_model()
@@ -292,7 +291,7 @@ def test_coco_json_prediction(self):
292291
)
293292

294293
# init model
295-
download_yolov5s6_model()
294+
download_yolov5n_model()
296295

297296
# prepare paths
298297
dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json"
@@ -304,7 +303,7 @@ def test_coco_json_prediction(self):
304303
shutil.rmtree(project_dir)
305304
predict(
306305
model_type="yolov5",
307-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
306+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
308307
model_config_path=None,
309308
model_confidence_threshold=0.4,
310309
model_device=None,

tests/test_yolov5model.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
import numpy as np
77

88
from sahi.utils.cv import read_image
9-
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
9+
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model, download_yolov5s6_model
1010

1111

1212
class TestYolov5DetectionModel(unittest.TestCase):
1313
def test_load_model(self):
1414
from sahi.model import Yolov5DetectionModel
1515

16-
download_yolov5s6_model()
16+
download_yolov5n_model()
1717

1818
yolov5_detection_model = Yolov5DetectionModel(
19-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
19+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
2020
confidence_threshold=0.3,
2121
device=None,
2222
category_remapping=None,
@@ -29,10 +29,10 @@ def test_perform_inference(self):
2929
from sahi.model import Yolov5DetectionModel
3030

3131
# init model
32-
download_yolov5s6_model()
32+
download_yolov5n_model()
3333

3434
yolov5_detection_model = Yolov5DetectionModel(
35-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
35+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
3636
confidence_threshold=0.5,
3737
device=None,
3838
category_remapping=None,
@@ -56,7 +56,7 @@ def test_perform_inference(self):
5656
break
5757

5858
# compare
59-
desired_bbox = [321, 322, 383, 362]
59+
desired_bbox = [321, 325, 384, 365]
6060
predicted_bbox = list(map(int, box[:4].tolist()))
6161
margin = 2
6262
for ind, point in enumerate(predicted_bbox):
@@ -67,10 +67,10 @@ def test_convert_original_predictions(self):
6767
from sahi.model import Yolov5DetectionModel
6868

6969
# init model
70-
download_yolov5s6_model()
70+
download_yolov5n_model()
7171

7272
yolov5_detection_model = Yolov5DetectionModel(
73-
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
73+
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
7474
confidence_threshold=0.5,
7575
device=None,
7676
category_remapping=None,
@@ -89,20 +89,20 @@ def test_convert_original_predictions(self):
8989
object_prediction_list = yolov5_detection_model.object_prediction_list
9090

9191
# compare
92-
self.assertEqual(len(object_prediction_list), 9)
92+
self.assertEqual(len(object_prediction_list), 8)
9393
self.assertEqual(object_prediction_list[0].category.id, 2)
9494
self.assertEqual(object_prediction_list[0].category.name, "car")
95-
desired_bbox = [321, 322, 62, 40]
95+
desired_bbox = [321, 325, 63, 40]
9696
predicted_bbox = object_prediction_list[0].bbox.to_coco_bbox()
9797
margin = 2
9898
for ind, point in enumerate(predicted_bbox):
9999
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
100100
self.assertEqual(object_prediction_list[5].category.id, 2)
101101
self.assertEqual(object_prediction_list[5].category.name, "car")
102-
self.assertEqual(
103-
object_prediction_list[5].bbox.to_coco_bbox(),
104-
[617, 195, 24, 23],
105-
)
102+
desired_bbox = [701, 234, 20, 17]
103+
predicted_bbox = object_prediction_list[5].bbox.to_coco_bbox()
104+
for ind, point in enumerate(predicted_bbox):
105+
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
106106

107107
def test_create_original_predictions_from_object_prediction_list(
108108
self,

0 commit comments

Comments
 (0)