|
| 1 | +# OBSS SAHI Tool |
| 2 | +# Code written by Fatih C Akyon, 2020. |
| 3 | + |
| 4 | +import logging |
| 5 | +from typing import Any, Dict, List, Optional |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from yaml import safe_load |
| 9 | + |
| 10 | +from sahi.models.base import DetectionModel |
| 11 | +from sahi.prediction import ObjectPrediction |
| 12 | +from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list |
| 13 | +from sahi.utils.import_utils import check_requirements |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +class YoloNasDetectionModel(DetectionModel): |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + model_name: str, |
| 22 | + model_path: Optional[str] = None, |
| 23 | + class_names_yaml_path: Optional[List[str]] = None, |
| 24 | + **kwargs, |
| 25 | + ): |
| 26 | + if model_name is not None and not isinstance(model_name, str): |
| 27 | + raise TypeError( |
| 28 | + f"model_name should be a string, got {model_name} with type of '{model_name.__class__.__name__}'" |
| 29 | + ) |
| 30 | + if model_name not in ["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"]: |
| 31 | + raise ValueError(f"Unsupported model type {model_name}") |
| 32 | + if not model_path: # use pretrained models downloaded from Deci-AI remote client |
| 33 | + self.pretrained_weights = "coco" |
| 34 | + self.class_names = None |
| 35 | + self.num_classes = None |
| 36 | + else: # use local / custom trained models |
| 37 | + self.pretrained_weights = None |
| 38 | + if not class_names_yaml_path: |
| 39 | + raise ValueError( |
| 40 | + "'class_names_yaml_path' should be provided for the models that have custom class mapping" |
| 41 | + ) |
| 42 | + with open(class_names_yaml_path, "r") as fs: |
| 43 | + yaml_content = safe_load(fs) |
| 44 | + if not isinstance(yaml_content, list): |
| 45 | + raise ValueError( |
| 46 | + "Invalid yaml file format, make sure your class names are given in list format in yaml" |
| 47 | + ) |
| 48 | + self.class_names = yaml_content |
| 49 | + self.num_classes = len(self.class_names) |
| 50 | + self.model_name = model_name |
| 51 | + super().__init__(model_path=model_path, **kwargs) |
| 52 | + |
| 53 | + def check_dependencies(self) -> None: |
| 54 | + check_requirements(["torch", "super_gradients"]) |
| 55 | + |
| 56 | + def load_model(self): |
| 57 | + """ |
| 58 | + Detection model is initialized and set to self.model. |
| 59 | + """ |
| 60 | + from super_gradients.training import models |
| 61 | + |
| 62 | + try: |
| 63 | + model = models.get( |
| 64 | + model_name=self.model_name, |
| 65 | + checkpoint_path=self.model_path, |
| 66 | + pretrained_weights=self.pretrained_weights, |
| 67 | + num_classes=self.num_classes, |
| 68 | + ).to(device=self.device) |
| 69 | + self.set_model(model) |
| 70 | + except Exception as e: |
| 71 | + raise TypeError("Load model failed. Provided model weights and model_name might be mismatching. ", e) |
| 72 | + |
| 73 | + def set_model(self, model: Any): |
| 74 | + """ |
| 75 | + Sets the underlying YoloNas model. |
| 76 | + Args: |
| 77 | + model: Any |
| 78 | + A YoloNas model |
| 79 | + """ |
| 80 | + from super_gradients.training.processing.processing import get_pretrained_processing_params |
| 81 | + |
| 82 | + if model.__class__.__module__.split(".")[-1] != "yolo_nas_variants": |
| 83 | + raise Exception(f"Not a YoloNas model: {type(model)}") |
| 84 | + |
| 85 | + # set default processing params for yolo_nas model |
| 86 | + processing_params = get_pretrained_processing_params(model_name=self.model_name, pretrained_weights="coco") |
| 87 | + processing_params["conf"] = self.confidence_threshold |
| 88 | + if self.class_names: # override class names for custom trained models |
| 89 | + processing_params["class_names"] = self.class_names |
| 90 | + model.set_dataset_processing_params(**processing_params) |
| 91 | + self.model = model |
| 92 | + |
| 93 | + # set category_mapping |
| 94 | + if not self.category_mapping: |
| 95 | + category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)} |
| 96 | + self.category_mapping = category_mapping |
| 97 | + |
| 98 | + def perform_inference(self, image: np.ndarray): |
| 99 | + """ |
| 100 | + Prediction is performed using self.model and the prediction result is set to self._original_predictions. |
| 101 | + Args: |
| 102 | + image: np.ndarray |
| 103 | + A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. |
| 104 | + """ |
| 105 | + |
| 106 | + # Confirm model is loaded |
| 107 | + if self.model is None: |
| 108 | + raise ValueError("Model is not loaded, load it by calling .load_model()") |
| 109 | + prediction_result = list(self.model.predict(image)) |
| 110 | + self._original_predictions = prediction_result |
| 111 | + |
| 112 | + @property |
| 113 | + def num_categories(self): |
| 114 | + """ |
| 115 | + Returns number of categories |
| 116 | + """ |
| 117 | + return len(self.model._class_names) |
| 118 | + |
| 119 | + @property |
| 120 | + def has_mask(self): |
| 121 | + """ |
| 122 | + Returns if model output contains segmentation mask |
| 123 | + """ |
| 124 | + return False |
| 125 | + |
| 126 | + @property |
| 127 | + def category_names(self): |
| 128 | + return self.model._class_names |
| 129 | + |
| 130 | + def _create_object_prediction_list_from_original_predictions( |
| 131 | + self, |
| 132 | + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], |
| 133 | + full_shape_list: Optional[List[List[int]]] = None, |
| 134 | + ): |
| 135 | + """ |
| 136 | + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to |
| 137 | + self._object_prediction_list_per_image. |
| 138 | + Args: |
| 139 | + shift_amount_list: list of list |
| 140 | + To shift the box and mask predictions from sliced image to full sized image, should |
| 141 | + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] |
| 142 | + full_shape_list: list of list |
| 143 | + Size of the full image after shifting, should be in the form of |
| 144 | + List[[height, width],[height, width],...] |
| 145 | + """ |
| 146 | + original_predictions = self._original_predictions |
| 147 | + |
| 148 | + # compatilibty for sahi v0.8.15 |
| 149 | + shift_amount_list = fix_shift_amount_list(shift_amount_list) |
| 150 | + full_shape_list = fix_full_shape_list(full_shape_list) |
| 151 | + |
| 152 | + # handle all predictions |
| 153 | + object_prediction_list_per_image = [] |
| 154 | + for image_ind, image_predictions in enumerate(original_predictions): |
| 155 | + shift_amount = shift_amount_list[image_ind] |
| 156 | + full_shape = None if full_shape_list is None else full_shape_list[image_ind] |
| 157 | + object_prediction_list = [] |
| 158 | + # process predictions |
| 159 | + preds = image_predictions.prediction |
| 160 | + for bbox_xyxy, score, category_id in zip(preds.bboxes_xyxy, preds.confidence, preds.labels): |
| 161 | + bbox = bbox_xyxy |
| 162 | + category_name = self.category_mapping[str(int(category_id))] |
| 163 | + # fix negative box coords |
| 164 | + bbox[0] = max(0, bbox[0]) |
| 165 | + bbox[1] = max(0, bbox[1]) |
| 166 | + bbox[2] = max(0, bbox[2]) |
| 167 | + bbox[3] = max(0, bbox[3]) |
| 168 | + |
| 169 | + # fix out of image box coords |
| 170 | + if full_shape is not None: |
| 171 | + bbox[0] = min(full_shape[1], bbox[0]) |
| 172 | + bbox[1] = min(full_shape[0], bbox[1]) |
| 173 | + bbox[2] = min(full_shape[1], bbox[2]) |
| 174 | + bbox[3] = min(full_shape[0], bbox[3]) |
| 175 | + |
| 176 | + # ignore invalid predictions |
| 177 | + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): |
| 178 | + logger.warning(f"ignoring invalid prediction with bbox: {bbox}") |
| 179 | + continue |
| 180 | + |
| 181 | + object_prediction = ObjectPrediction( |
| 182 | + bbox=bbox, |
| 183 | + category_id=int(category_id), |
| 184 | + score=score, |
| 185 | + bool_mask=None, |
| 186 | + category_name=category_name, |
| 187 | + shift_amount=shift_amount, |
| 188 | + full_shape=full_shape, |
| 189 | + ) |
| 190 | + object_prediction_list.append(object_prediction) |
| 191 | + object_prediction_list_per_image.append(object_prediction_list) |
| 192 | + |
| 193 | + self._object_prediction_list_per_image = object_prediction_list_per_image |
0 commit comments