Commit 21ecb28

improve result visuals/plots (#271)
* improve result visuals/plots
* reduce GPU to CPU copy overhead for yolov5
* don't export visuals by default
1 parent 41a67c2 commit 21ecb28

6 files changed: +246 -190 lines

demo/inference_for_mmdetection.ipynb: +44 -38

Large diffs are not rendered by default.

demo/inference_for_yolov5.ipynb: +37 -62

Large diffs are not rendered by default.

sahi/model.py: +20 -11

@@ -1,13 +1,22 @@
 # OBSS SAHI Tool
 # Code written by Fatih C Akyon, 2020.
 
+import logging
+import os
 from typing import Dict, List, Optional, Union
 
 import numpy as np
 
 from sahi.prediction import ObjectPrediction
 from sahi.utils.torch import cuda_is_available, empty_cuda_cache
 
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+
 
 class DetectionModel:
     def __init__(
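The module-level logging setup added above reads the LOGLEVEL environment variable once, when sahi.model is first imported. A minimal sketch of how a caller might lower the verbosity; the script below is illustrative and not part of this commit, and the Yolov5DetectionModel import is assumed from the sahi package:

import os

# must be set before sahi.model is imported, because basicConfig() runs at
# import time and reads the variable exactly once
os.environ["LOGLEVEL"] = "WARNING"  # hypothetical choice; the default is "INFO"

from sahi.model import Yolov5DetectionModel  # assumed class name from sahi

# the "ignoring invalid prediction" messages emitted via logger.warning()
# are still shown, while INFO-level output is suppressed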
@@ -299,15 +308,15 @@ def _create_object_prediction_list_from_original_predictions(
 
                 # ignore invalid predictions
                 if bbox[0] > bbox[2] or bbox[1] > bbox[3] or bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0:
-                    print(f"ignoring invalid prediction with bbox: {bbox}")
+                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                     continue
                 if full_shape is not None and (
                     bbox[1] > full_shape[0]
                     or bbox[3] > full_shape[0]
                     or bbox[0] > full_shape[1]
                     or bbox[2] > full_shape[1]
                 ):
-                    print(f"ignoring invalid prediction with bbox: {bbox}")
+                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                     continue
 
                 object_prediction = ObjectPrediction(
@@ -461,29 +470,29 @@ def _create_object_prediction_list_from_original_predictions(
         original_predictions = self._original_predictions
 
         # handle only first image (batch=1)
-        predictions_in_xyxy_format = original_predictions.xyxy[0]
+        predictions_in_xyxy_format = original_predictions.xyxy[0].cpu().detach().numpy()
 
         object_prediction_list = []
 
         # process predictions
         for prediction in predictions_in_xyxy_format:
-            x1 = int(prediction[0].item())
-            y1 = int(prediction[1].item())
-            x2 = int(prediction[2].item())
-            y2 = int(prediction[3].item())
+            x1 = int(prediction[0])
+            y1 = int(prediction[1])
+            x2 = int(prediction[2])
+            y2 = int(prediction[3])
             bbox = [x1, y1, x2, y2]
-            score = prediction[4].item()
-            category_id = int(prediction[5].item())
+            score = prediction[4]
+            category_id = int(prediction[5])
             category_name = original_predictions.names[category_id]
 
             # ignore invalid predictions
             if bbox[0] > bbox[2] or bbox[1] > bbox[3] or bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0:
-                print(f"ignoring invalid prediction with bbox: {bbox}")
+                logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                 continue
             if full_shape is not None and (
                 bbox[1] > full_shape[0] or bbox[3] > full_shape[0] or bbox[0] > full_shape[1] or bbox[2] > full_shape[1]
             ):
-                print(f"ignoring invalid prediction with bbox: {bbox}")
+                logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                 continue
 
             object_prediction = ObjectPrediction(
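The YOLOv5 hunk above replaces per-element .item() calls on a (possibly CUDA) tensor with a single .cpu().detach().numpy() transfer, which is the "reduce GPU to CPU copy overhead" part of the commit. A minimal sketch of the idea, using a random CPU tensor as a stand-in for original_predictions.xyxy[0]; the shape and values are illustrative assumptions:

import torch

preds = torch.rand(100, 6)  # stand-in for original_predictions.xyxy[0]; call .cuda() when a GPU is available

# before: each .item() on a CUDA tensor triggers its own device-to-host copy
boxes_slow = [[int(p[0].item()), int(p[1].item()), int(p[2].item()), int(p[3].item())] for p in preds]

# after: one bulk transfer, then plain numpy indexing on the host
preds_np = preds.cpu().detach().numpy()
boxes_fast = [[int(p[0]), int(p[1]), int(p[2]), int(p[3])] for p in preds_np]

assert boxes_slow == boxes_fast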

sahi/predict.py: +41 -40

@@ -280,15 +280,15 @@ def predict(
     postprocess_match_metric: str = "IOS",
     postprocess_match_threshold: float = 0.5,
     postprocess_class_agnostic: bool = False,
-    export_visual: bool = True,
+    export_visual: bool = False,
     export_pickle: bool = False,
     export_crop: bool = False,
     dataset_json_path: bool = None,
     project: str = "runs/predict",
     name: str = "exp",
-    visual_bbox_thickness: int = 1,
-    visual_text_size: float = 0.3,
-    visual_text_thickness: int = 1,
+    visual_bbox_thickness: int = None,
+    visual_text_size: float = None,
+    visual_text_thickness: int = None,
     visual_export_format: str = "png",
     verbose: int = 1,
 ):
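Since export_visual now defaults to False, rendered prediction images are opt-in. A hypothetical call sketch; only export_visual and the visual_* keywords are taken from this commit, while model_type, model_path and source are assumed parameters of predict() with placeholder values:

from sahi.predict import predict

predict(
    model_type="yolov5",        # assumed parameter, placeholder value
    model_path="yolov5s.pt",    # placeholder weights path
    source="demo_data/",        # placeholder image folder
    export_visual=True,         # opt back in; the default is now False
    visual_bbox_thickness=2,    # optional; None lets the visualizer pick a value
    visual_text_size=0.5,
)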
@@ -469,43 +469,44 @@ def predict(
             coco_prediction_json = coco_prediction.json
             if coco_prediction_json["bbox"]:
                 coco_json.append(coco_prediction_json)
-            # convert ground truth annotations to object_prediction_list
-            coco_image: CocoImage = coco.images[ind]
-            object_prediction_gt_list: List[ObjectPrediction] = []
-            for coco_annotation in coco_image.annotations:
-                coco_annotation_dict = coco_annotation.json
-                category_name = coco_annotation.category_name
-                full_shape = [coco_image.height, coco_image.width]
-                object_prediction_gt = ObjectPrediction.from_coco_annotation_dict(
-                    annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape
+            if export_visual:
+                # convert ground truth annotations to object_prediction_list
+                coco_image: CocoImage = coco.images[ind]
+                object_prediction_gt_list: List[ObjectPrediction] = []
+                for coco_annotation in coco_image.annotations:
+                    coco_annotation_dict = coco_annotation.json
+                    category_name = coco_annotation.category_name
+                    full_shape = [coco_image.height, coco_image.width]
+                    object_prediction_gt = ObjectPrediction.from_coco_annotation_dict(
+                        annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape
+                    )
+                    object_prediction_gt_list.append(object_prediction_gt)
+                # export visualizations with ground truths
+                output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent)
+                color = (0, 255, 0)  # original annotations in green
+                result = visualize_object_predictions(
+                    np.ascontiguousarray(image_as_pil),
+                    object_prediction_list=object_prediction_gt_list,
+                    rect_th=visual_bbox_thickness,
+                    text_size=visual_text_size,
+                    text_th=visual_text_thickness,
+                    color=color,
+                    output_dir=None,
+                    file_name=None,
+                    export_format=None,
+                )
+                color = (255, 0, 0)  # model predictions in red
+                _ = visualize_object_predictions(
+                    result["image"],
+                    object_prediction_list=object_prediction_list,
+                    rect_th=visual_bbox_thickness,
+                    text_size=visual_text_size,
+                    text_th=visual_text_thickness,
+                    color=color,
+                    output_dir=output_dir,
+                    file_name=filename_without_extension,
+                    export_format=visual_export_format,
                 )
-                object_prediction_gt_list.append(object_prediction_gt)
-            # export visualizations with ground truths
-            output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent)
-            color = (0, 255, 0)  # original annotations in green
-            result = visualize_object_predictions(
-                np.ascontiguousarray(image_as_pil),
-                object_prediction_list=object_prediction_gt_list,
-                rect_th=visual_bbox_thickness,
-                text_size=visual_text_size,
-                text_th=visual_text_thickness,
-                color=color,
-                output_dir=None,
-                file_name=None,
-                export_format=None,
-            )
-            color = (255, 0, 0)  # model predictions in red
-            _ = visualize_object_predictions(
-                result["image"],
-                object_prediction_list=object_prediction_list,
-                rect_th=visual_bbox_thickness,
-                text_size=visual_text_size,
-                text_th=visual_text_thickness,
-                color=color,
-                output_dir=output_dir,
-                file_name=filename_without_extension,
-                export_format=visual_export_format,
-            )
 
         time_start = time.time()
         # export prediction boxes

sahi/prediction.py: +5 -5

@@ -164,15 +164,15 @@ def __init__(
         self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
         self.durations_in_seconds = durations_in_seconds
 
-    def export_visuals(self, export_dir: str):
+    def export_visuals(self, export_dir: str, text_size: float = None, rect_th: int = None):
         Path(export_dir).mkdir(parents=True, exist_ok=True)
         visualize_object_predictions(
             image=np.ascontiguousarray(self.image),
             object_prediction_list=self.object_prediction_list,
-            rect_th=1,
-            text_size=0.3,
-            text_th=1,
-            color=(0, 0, 0),
+            rect_th=rect_th,
+            text_size=text_size,
+            text_th=None,
+            color=None,
             output_dir=export_dir,
             file_name="prediction_visual",
             export_format="png",