Skip to content

Commit c7d917d

Browse files
Add Deci-AI YOLO-NAS models support to SAHI (#874)
Co-authored-by: fatih <[email protected]>
1 parent 1cd7a1b commit c7d917d

11 files changed

+1212
-1
lines changed

.github/workflows/ci.yml

+4
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ jobs:
109109
run: >
110110
pip install ultralytics==8.0.99
111111
112+
- name: Install super-gradients
113+
run: >
114+
pip install super-gradients==3.1.1
115+
112116
- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
113117
run: |
114118
python -m unittest

.github/workflows/ci_torch1.10.yml

+4
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ jobs:
115115
- name: Install ultralytics
116116
run: >
117117
pip install ultralytics==8.0.99
118+
119+
- name: Install super-gradients
120+
run: >
121+
pip install super-gradients==3.1.1
118122
119123
- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
120124
run: |

.github/workflows/package_testing.yml

+4
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ jobs:
9191
run: >
9292
pip install ultralytics==8.0.99
9393
94+
- name: Install super-gradients
95+
run: >
96+
pip install super-gradients==3.1.1
97+
9498
- name: Install latest SAHI package
9599
run: >
96100
pip install --upgrade --force-reinstall sahi

demo/inference_for_yolonas.ipynb

+713
Large diffs are not rendered by default.

sahi/auto_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"huggingface": "HuggingfaceDetectionModel",
1111
"torchvision": "TorchVisionDetectionModel",
1212
"yolov5sparse": "Yolov5SparseDetectionModel",
13+
"yolonas": "YoloNasDetectionModel",
1314
}
1415

1516

sahi/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import base, detectron2, huggingface, mmdet, torchvision, yolov5
1+
from . import base, detectron2, huggingface, mmdet, torchvision, yolonas, yolov5

sahi/models/yolonas.py

+193
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# OBSS SAHI Tool
2+
# Code written by Fatih C Akyon, 2020.
3+
4+
import logging
5+
from typing import Any, Dict, List, Optional
6+
7+
import numpy as np
8+
from yaml import safe_load
9+
10+
from sahi.models.base import DetectionModel
11+
from sahi.prediction import ObjectPrediction
12+
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
13+
from sahi.utils.import_utils import check_requirements
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class YoloNasDetectionModel(DetectionModel):
    """
    SAHI detection model wrapper for Deci-AI YOLO-NAS models loaded through super-gradients.

    Either the "coco" pretrained weights are requested (no model_path given) or a local
    checkpoint is loaded together with a yaml file that lists its class names.
    """

    def __init__(
        self,
        model_name: str,
        model_path: Optional[str] = None,
        class_names_yaml_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Validates the YOLO-NAS specific arguments and then initializes the base DetectionModel.
        Args:
            model_name: str
                Architecture name, one of "yolo_nas_s", "yolo_nas_m", "yolo_nas_l".
            model_path: str
                Path of a local checkpoint. When empty, "coco" pretrained weights are used.
            class_names_yaml_path: str
                Path of a yaml file that contains the class names as a list. Required
                whenever model_path is given.
            **kwargs:
                Passed unchanged to DetectionModel.__init__.
        Raises:
            TypeError: if model_name is neither None nor a string.
            ValueError: if model_name is not supported, if class_names_yaml_path is missing
                for a local checkpoint, or if the yaml content is not a list.
        """
        if model_name is not None and not isinstance(model_name, str):
            raise TypeError(
                f"model_name should be a string, got {model_name} with type of '{model_name.__class__.__name__}'"
            )
        if model_name not in ["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"]:
            raise ValueError(f"Unsupported model type {model_name}")
        if not model_path:  # use pretrained models downloaded from Deci-AI remote client
            self.pretrained_weights = "coco"
            self.class_names = None
            self.num_classes = None
        else:  # use local / custom trained models
            self.pretrained_weights = None
            if not class_names_yaml_path:
                raise ValueError(
                    "'class_names_yaml_path' should be provided for the models that have custom class mapping"
                )
            # explicit encoding so class names are read the same way on every platform
            with open(class_names_yaml_path, "r", encoding="utf-8") as fs:
                yaml_content = safe_load(fs)
            if not isinstance(yaml_content, list):
                raise ValueError(
                    "Invalid yaml file format, make sure your class names are given in list format in yaml"
                )
            self.class_names = yaml_content
            self.num_classes = len(self.class_names)
        # model_name has to be set before the base initializer runs
        # NOTE(review): presumably the base __init__ may call load_model(), which reads it - confirm
        self.model_name = model_name
        super().__init__(model_path=model_path, **kwargs)

    def check_dependencies(self) -> None:
        """
        Checks that the "torch" and "super_gradients" packages are available.
        """
        check_requirements(["torch", "super_gradients"])

    def load_model(self):
        """
        Detection model is initialized and set to self.model.
        Raises:
            TypeError: if creating the model or setting it fails; the original error is
                kept both as the second exception argument and as the chained cause.
        """
        from super_gradients.training import models

        try:
            model = models.get(
                model_name=self.model_name,
                checkpoint_path=self.model_path,
                pretrained_weights=self.pretrained_weights,
                num_classes=self.num_classes,
            ).to(device=self.device)
            self.set_model(model)
        except Exception as e:
            # chain the cause so the underlying traceback is not lost
            raise TypeError(
                "Load model failed. Provided model weights and model_name might be mismatching. ", e
            ) from e

    def set_model(self, model: Any):
        """
        Sets the underlying YoloNas model.
        Args:
            model: Any
                A YoloNas model
        Raises:
            Exception: if the class of the given model is not defined in a module named
                "yolo_nas_variants".
        """
        from super_gradients.training.processing.processing import get_pretrained_processing_params

        if model.__class__.__module__.split(".")[-1] != "yolo_nas_variants":
            raise Exception(f"Not a YoloNas model: {type(model)}")

        # set default processing params for yolo_nas model
        processing_params = get_pretrained_processing_params(model_name=self.model_name, pretrained_weights="coco")
        processing_params["conf"] = self.confidence_threshold
        if self.class_names:  # override class names for custom trained models
            processing_params["class_names"] = self.class_names
        model.set_dataset_processing_params(**processing_params)
        self.model = model

        # set category_mapping
        if not self.category_mapping:
            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
            self.category_mapping = category_mapping

    def perform_inference(self, image: np.ndarray):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        Raises:
            ValueError: if no model has been loaded yet.
        """

        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")
        prediction_result = list(self.model.predict(image))
        self._original_predictions = prediction_result

    @property
    def num_categories(self):
        """
        Returns number of categories
        """
        return len(self.model._class_names)

    @property
    def has_mask(self):
        """
        Returns if model output contains segmentation mask
        """
        return False

    @property
    def category_names(self):
        """
        Returns the class names stored on the underlying model
        """
        return self.model._class_names

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        # (the list default above is kept for interface compatibility; it is never mutated here)
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []
        for image_ind, image_predictions in enumerate(original_predictions):
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
            object_prediction_list = []
            # process predictions
            preds = image_predictions.prediction
            for bbox_xyxy, score, category_id in zip(preds.bboxes_xyxy, preds.confidence, preds.labels):
                # copy the coordinates so that clamping below does not write back into
                # the raw predictions stored in self._original_predictions
                bbox = [float(coord) for coord in bbox_xyxy]
                category_name = self.category_mapping[str(int(category_id))]
                # fix negative box coords
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
                bbox[2] = max(0, bbox[2])
                bbox[3] = max(0, bbox[3])

                # fix out of image box coords
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # ignore invalid predictions
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=int(category_id),
                    score=score,
                    bool_mask=None,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image

sahi/predict.py

+2
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def predict(
366366
verbose: int = 1,
367367
return_dict: bool = False,
368368
force_postprocess_type: bool = False,
369+
**kwargs,
369370
):
370371
"""
371372
Performs prediction for all present images in given folder.
@@ -512,6 +513,7 @@ def predict(
512513
category_remapping=model_category_remapping,
513514
load_at_init=False,
514515
image_size=image_size,
516+
**kwargs,
515517
)
516518
detection_model.load_model()
517519
time_end = time.time() - time_start

sahi/utils/yolonas.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import urllib.request
2+
from os import path
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
7+
class YoloNasTestConstants:
    """Download URLs and default local paths of the COCO YOLO-NAS weights (S / M / L)."""

    YOLONAS_S_MODEL_URL = "https://sghub.deci.ai/models/yolo_nas_s_coco.pth"
    YOLONAS_S_MODEL_PATH = "tests/data/models/yolonas/yolo_nas_s_coco.pt"

    YOLONAS_M_MODEL_URL = "https://sghub.deci.ai/models/yolo_nas_m_coco.pth"
    YOLONAS_M_MODEL_PATH = "tests/data/models/yolonas/yolo_nas_m_coco.pt"

    YOLONAS_L_MODEL_URL = "https://sghub.deci.ai/models/yolo_nas_l_coco.pth"
    YOLONAS_L_MODEL_PATH = "tests/data/models/yolonas/yolo_nas_l_coco.pt"


def _download_if_missing(url: str, destination_path: str) -> None:
    """Fetch `url` into `destination_path` unless a file already exists there."""
    destination = Path(destination_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    if destination.exists():
        return

    # Download next to the target and rename afterwards, so an interrupted
    # transfer never leaves a truncated file that a later call would accept.
    partial = destination.with_name(destination.name + ".part")
    try:
        urllib.request.urlretrieve(url, str(partial))
        partial.replace(destination)
    finally:
        if partial.exists():
            partial.unlink()


def download_yolonas_s_model(destination_path: Optional[str] = None):
    """
    Downloads the yolo_nas_s coco weights if they are not already present.
    Args:
        destination_path: str
            Target file path, defaults to YoloNasTestConstants.YOLONAS_S_MODEL_PATH.
    """
    if destination_path is None:
        destination_path = YoloNasTestConstants.YOLONAS_S_MODEL_PATH

    _download_if_missing(YoloNasTestConstants.YOLONAS_S_MODEL_URL, destination_path)


def download_yolonas_m_model(destination_path: Optional[str] = None):
    """
    Downloads the yolo_nas_m coco weights if they are not already present.
    Args:
        destination_path: str
            Target file path, defaults to YoloNasTestConstants.YOLONAS_M_MODEL_PATH.
    """
    if destination_path is None:
        destination_path = YoloNasTestConstants.YOLONAS_M_MODEL_PATH

    _download_if_missing(YoloNasTestConstants.YOLONAS_M_MODEL_URL, destination_path)


def download_yolonas_l_model(destination_path: Optional[str] = None):
    """
    Downloads the yolo_nas_l coco weights if they are not already present.
    Args:
        destination_path: str
            Target file path, defaults to YoloNasTestConstants.YOLONAS_L_MODEL_PATH.
    """
    if destination_path is None:
        destination_path = YoloNasTestConstants.YOLONAS_L_MODEL_PATH

    _download_if_missing(YoloNasTestConstants.YOLONAS_L_MODEL_URL, destination_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
- person
2+
- bicycle
3+
- car
4+
- motorcycle
5+
- airplane
6+
- bus
7+
- train
8+
- truck
9+
- boat
10+
- traffic light
11+
- fire hydrant
12+
- stop sign
13+
- parking meter
14+
- bench
15+
- bird
16+
- cat
17+
- dog
18+
- horse
19+
- sheep
20+
- cow
21+
- elephant
22+
- bear
23+
- zebra
24+
- giraffe
25+
- backpack
26+
- umbrella
27+
- handbag
28+
- tie
29+
- suitcase
30+
- frisbee
31+
- skis
32+
- snowboard
33+
- sports ball
34+
- kite
35+
- baseball bat
36+
- baseball glove
37+
- skateboard
38+
- surfboard
39+
- tennis racket
40+
- bottle
41+
- wine glass
42+
- cup
43+
- fork
44+
- knife
45+
- spoon
46+
- bowl
47+
- banana
48+
- apple
49+
- sandwich
50+
- orange
51+
- broccoli
52+
- carrot
53+
- hot dog
54+
- pizza
55+
- donut
56+
- cake
57+
- chair
58+
- couch
59+
- potted plant
60+
- bed
61+
- dining table
62+
- toilet
63+
- tv
64+
- laptop
65+
- mouse
66+
- remote
67+
- keyboard
68+
- cell phone
69+
- microwave
70+
- oven
71+
- toaster
72+
- sink
73+
- refrigerator
74+
- book
75+
- clock
76+
- vase
77+
- scissors
78+
- teddy bear
79+
- hair drier
80+
- toothbrush

0 commit comments

Comments
 (0)