Skip to content

Commit f303345

Browse files
committed
Merge branch 'develop'
2 parents 4bc9edc + 733e773 commit f303345

12 files changed

+240
-362
lines changed

CHANGELOG.md

+11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
## [1.2.5] - 2024-05-04
2+
### Added
3+
- Added exception in `mltu.dataProvider.DataProvider` to raise ValueError when dataset is not iterable
4+
- Added custom training code for YoloV8 object detector: `Tutorials\11_Yolov8\train_yolov8.py`
5+
- Added custom trained inference code for YoloV8 object detector: `Tutorials\11_Yolov8\test_yolov8.py`
6+
7+
### Changed
8+
- Fixed `RandomElasticTransform` in `mltu.augmentors` so that the elastic transformation does not exceed image boundaries
9+
- Modified `YoloPreprocessor` in `mltu.torch.yolo.preprocessors` to output dictionary with np.arrays instead of lists
10+
11+
112
## [1.2.4] - 2024-03-21
213
### Added
314
- Added `RandomElasticTransform` to `mltu.augmentors` to work with `Image` objects

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ Each tutorial has its own requirements.txt file for a specific mltu version. As
2626
8. [Handwriting words recognition with PyTorch](https://pylessons.com/handwriting-recognition-pytorch), code in ```Tutorials\08_handwriting_recognition_torch``` folder;
2727
9. [Transformer training with TensorFlow for Translation task](https://pylessons.com/transformers-training), code in ```Tutorials\09_translation_transformer``` folder;
2828
10. [Speech Recognition in Python | finetune wav2vec2 model for a custom ASR model](https://youtu.be/h6ooEGzjkj0), code in ```Tutorials\10_wav2vec2_torch``` folder;
29-
11. [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY), code in ```Tutorials\11_Yolov8``` folder;
29+
11. [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY), code in ```Tutorials\11_Yolov8``` folder;
30+
12. [YOLOv8: Customizing Object Detector training](https://youtu.be/ysYiV1CbCyY), code in ```Tutorials\11_Yolov8\train_yolov8.py``` file;

Tutorials/11_Yolov8/README.md

+174-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Run Ultralytics YOLOv8 pretrained model
22

3-
YouTube tutorial link: [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY)
3+
YouTube tutorial links:
4+
- [YOLOv8: Real-Time Object Detection Simplified](https://youtu.be/vegL__weCxY);
5+
- [YOLOv8: Customizing Object Detector training](https://youtu.be/ysYiV1CbCyY);
46

57
First, I recommend you to install the required packages in a virtual environment:
68
```bash
7-
mltu==1.2.3
9+
mltu==1.2.5
810
ultralytics==8.1.28
911
torch==2.0.0
1012
torchvision==0.15.1
@@ -134,5 +136,175 @@ while True:
134136
break
135137

136138
cap.release()
139+
cv2.destroyAllWindows()
140+
```
141+
142+
## Customize YoloV8 Object Detector training:
143+
```python
144+
import os
145+
import time
146+
import torch
147+
from mltu.preprocessors import ImageReader
148+
from mltu.annotations.images import CVImage
149+
from mltu.transformers import ImageResizer, ImageShowCV2, ImageNormalizer
150+
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen, \
151+
RandomMirror, RandomFlip, RandomGaussianBlur, RandomSaltAndPepper, RandomDropBlock, RandomMosaic, RandomElasticTransform
152+
from mltu.torch.model import Model
153+
from mltu.torch.dataProvider import DataProvider
154+
from mltu.torch.yolo.annotation import VOCAnnotationReader
155+
from mltu.torch.yolo.preprocessors import YoloPreprocessor
156+
from mltu.torch.yolo.loss import v8DetectionLoss
157+
from mltu.torch.yolo.metrics import YoloMetrics
158+
from mltu.torch.yolo.optimizer import build_optimizer, AccumulativeOptimizer
159+
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay
160+
161+
from ultralytics.nn.tasks import DetectionModel
162+
from ultralytics.engine.model import Model as BaseModel
163+
164+
# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
165+
annotations_path = "Datasets/car-plate-detection/annotations"
166+
167+
# Create a dataset from the annotations, the dataset is a list of lists where each list contains the [image path, annotation path]
168+
dataset = [[None, os.path.join(annotations_path, f)] for f in os.listdir(annotations_path)]
169+
170+
# Make sure torch can see GPU device, it is not recommended to train with CPU
171+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172+
173+
img_size = 416
174+
labels = {0: "licence"}
175+
176+
# Create a data provider for the dataset
177+
data_provider = DataProvider(
178+
dataset=dataset,
179+
skip_validation=True,
180+
batch_size=16,
181+
data_preprocessors=[
182+
VOCAnnotationReader(labels=labels),
183+
ImageReader(CVImage),
184+
],
185+
transformers=[
186+
# ImageShowCV2(),
187+
ImageResizer(img_size, img_size),
188+
ImageNormalizer(transpose_axis=True),
189+
],
190+
batch_postprocessors=[
191+
YoloPreprocessor(device, img_size)
192+
],
193+
numpy=False,
194+
)
195+
196+
# split the dataset into train and test
197+
train_data_provider, val_data_provider = data_provider.split(0.9, shuffle=False)
198+
199+
# Attach augmentation to the train data provider
200+
train_data_provider.augmentors = [
201+
RandomBrightness(),
202+
RandomErodeDilate(),
203+
RandomSharpen(),
204+
RandomMirror(),
205+
RandomFlip(),
206+
RandomElasticTransform(),
207+
RandomGaussianBlur(),
208+
RandomSaltAndPepper(),
209+
RandomRotate(angle=10),
210+
RandomDropBlock(),
211+
RandomMosaic(),
212+
]
213+
214+
base_model = BaseModel("yolov8n.pt")
215+
# Create a YOLO model
216+
model = DetectionModel('yolov8n.yaml', nc=len(labels))
217+
218+
# Load the weight from base model
219+
try: model.load_state_dict(base_model.model.state_dict(), strict=False)
220+
except: pass
221+
222+
model.to(device)
223+
224+
for k, v in model.named_parameters():
225+
if any(x in k for x in [".dfl"]):
226+
print("freezing", k)
227+
v.requires_grad = False
228+
elif not v.requires_grad:
229+
v.requires_grad = True
230+
231+
lr = 1e-3
232+
optimizer = build_optimizer(model.model, name="AdamW", lr=lr, weight_decay=0.0, momentum=0.937, decay=0.0005)
233+
optimizer = AccumulativeOptimizer(optimizer, 16, 64)
234+
235+
# create model object that will handle training and testing of the network
236+
model = Model(
237+
model,
238+
optimizer,
239+
v8DetectionLoss(model),
240+
metrics=[YoloMetrics(nc=len(labels))],
241+
log_errors=False,
242+
output_path=f"Models/11_Yolov8/{int(time.time())}",
243+
clip_grad_norm=10.0,
244+
ema=True,
245+
)
246+
247+
modelCheckpoint = ModelCheckpoint(monitor="val_fitness", mode="max", save_best_only=True, verbose=True)
248+
tensorBoard = TensorBoard()
249+
earlyStopping = EarlyStopping(monitor="val_fitness", mode="max", patience=31, verbose=True)
250+
model2onnx = Model2onnx(input_shape=(1, 3, img_size, img_size), verbose=True, opset_version=14,
251+
dynamic_axes = {"input": {0: "batch_size", 2: "height", 3: "width"},
252+
"output": {0: "batch_size", 2: "anchors"}},
253+
metadata={"classes": labels})
254+
warmupCosineDecayBias = WarmupCosineDecay(lr_after_warmup=lr, final_lr=lr, initial_lr=0.1,
255+
warmup_steps=len(train_data_provider), warmup_epochs=10, ignore_param_groups=[1, 2]) # lr0
256+
warmupCosineDecay = WarmupCosineDecay(lr_after_warmup=lr, final_lr=lr/10, initial_lr=1e-7,
257+
warmup_steps=len(train_data_provider), warmup_epochs=10, decay_epochs=190, ignore_param_groups=[0]) # lr1 and lr2
258+
259+
# Train the model
260+
history = model.fit(
261+
train_data_provider,
262+
test_dataProvider=val_data_provider,
263+
epochs=200,
264+
callbacks=[
265+
modelCheckpoint,
266+
tensorBoard,
267+
earlyStopping,
268+
model2onnx,
269+
warmupCosineDecayBias,
270+
warmupCosineDecay
271+
]
272+
)
273+
```
274+
275+
## Test Custom trained YoloV8 Object Detector:
276+
```python
277+
import os
278+
import cv2
279+
from mltu.annotations.detections import Detections
280+
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector
281+
282+
# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
283+
images_path = "Datasets/car-plate-detection/images"
284+
285+
input_width, input_height = 416, 416
286+
confidence_threshold = 0.5
287+
iou_threshold = 0.5
288+
289+
detector = OnnxDetector("Models/11_Yolov8/1714135287/model.onnx", input_width, input_height, confidence_threshold, iou_threshold, force_cpu=False)
290+
291+
for image_path in os.listdir(images_path):
292+
293+
frame = cv2.imread(os.path.join(images_path, image_path))
294+
295+
# Perform Yolo object detection
296+
detections: Detections = detector(frame)
297+
298+
# Apply the detections to the frame
299+
frame = detections.applyToFrame(frame)
300+
301+
# Print the FPS
302+
print(detector.fps)
303+
304+
# Display the output image
305+
cv2.imshow("Object Detection", frame)
306+
if cv2.waitKey(0) & 0xFF == ord('q'):
307+
break
308+
137309
cv2.destroyAllWindows()
138310
```

Tutorials/11_Yolov8/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mltu==1.2.3
1+
mltu==1.2.5
22
ultralytics==8.1.28
33
torch==2.0.0
44
torchvision==0.15.1

Tutorials/11_Yolov8/run_pretrained.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import cv2
22
from ultralytics.engine.model import Model as BaseModel
3+
from mltu.annotations.detections import Detections
34
from mltu.torch.yolo.detectors.torch_detector import Detector as TorchDetector
45
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector
56

@@ -18,7 +19,7 @@
1819
break
1920

2021
# Perform Yolo object detection
21-
detections = detector(frame)
22+
detections: Detections = detector(frame)
2223

2324
# Apply the detections to the frame
2425
frame = detections.applyToFrame(frame)

Tutorials/11_Yolov8/test_yolov8.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import cv2
3+
from mltu.annotations.detections import Detections
4+
from mltu.torch.yolo.detectors.onnx_detector import Detector as OnnxDetector
5+
6+
# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
7+
images_path = "Datasets/car-plate-detection/images"
8+
9+
input_width, input_height = 416, 416
10+
confidence_threshold = 0.5
11+
iou_threshold = 0.5
12+
13+
detector = OnnxDetector("Models/11_Yolov8/1714135287/model.onnx", input_width, input_height, confidence_threshold, iou_threshold, force_cpu=False)
14+
15+
for image_path in os.listdir(images_path):
16+
17+
frame = cv2.imread(os.path.join(images_path, image_path))
18+
19+
# Perform Yolo object detection
20+
detections: Detections = detector(frame)
21+
22+
# Apply the detections to the frame
23+
frame = detections.applyToFrame(frame)
24+
25+
# Print the FPS
26+
print(detector.fps)
27+
28+
# Display the output image
29+
cv2.imshow("Object Detection", frame)
30+
if cv2.waitKey(0) & 0xFF == ord('q'):
31+
break
32+
33+
cv2.destroyAllWindows()

mltu/torch/yolo/train_yolo.py renamed to Tutorials/11_Yolov8/train_yolov8.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
from mltu.torch.yolo.loss import v8DetectionLoss
1414
from mltu.torch.yolo.metrics import YoloMetrics
1515
from mltu.torch.yolo.optimizer import build_optimizer, AccumulativeOptimizer
16-
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard, Model2onnx, WarmupCosineDecay
16+
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay
1717

1818
from ultralytics.nn.tasks import DetectionModel
1919
from ultralytics.engine.model import Model as BaseModel
2020

21-
21+
# https://www.kaggle.com/datasets/andrewmvd/car-plate-detection
2222
annotations_path = "Datasets/car-plate-detection/annotations"
2323

2424
# Create a dataset from the annotations, the dataset is a list of lists where each list contains the [image path, annotation path]
@@ -72,6 +72,7 @@
7272
# Create a YOLO model
7373
model = DetectionModel('yolov8n.yaml', nc=len(labels))
7474

75+
# Load the weight from base model
7576
try: model.load_state_dict(base_model.model.state_dict(), strict=False)
7677
except: pass
7778

@@ -95,7 +96,7 @@
9596
v8DetectionLoss(model),
9697
metrics=[YoloMetrics(nc=len(labels))],
9798
log_errors=False,
98-
output_path=f"Models/detector/{int(time.time())}",
99+
output_path=f"Models/11_Yolov8/{int(time.time())}",
99100
clip_grad_norm=10.0,
100101
ema=True,
101102
)

mltu/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.4"
1+
__version__ = "1.2.5"
22

33
from .annotations.images import Image
44
from .annotations.images import CVImage

mltu/augmentors.py

+2
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,8 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
916916
detections = []
917917
for detection in annotation:
918918
x_min, y_min, x_max, y_max = detection.xyxy_abs
919+
x_max = min(x_max, dx.shape[1] - 1)
920+
y_max = min(y_max, dy.shape[0] - 1)
919921
new_x_min = min(max(0, x_min + dx[y_min, x_min]), image.width - 1)
920922
new_y_min = min(max(0, y_min + dy[y_min, x_min]), image.height - 1)
921923
new_x_max = min(max(0, x_max + dx[y_max, x_max]), image.width - 1)

mltu/dataProvider.py

+4
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __init__(
7070
else:
7171
self.logger.info("Skipping Dataset validation...")
7272

73+
# Reject an empty dataset (len(dataset) == 0)
74+
if not len(dataset):
75+
raise ValueError("Dataset must be iterable")
76+
7377
if limit:
7478
self.logger.info(f"Limiting dataset to {limit} samples.")
7579
self._dataset = self._dataset[:limit]

mltu/torch/yolo/preprocessors.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import torch
2+
import typing
23
import numpy as np
34

45
class YoloPreprocessor:
5-
def __init__(self, device, imgsz=640):
6+
def __init__(self, device: torch.device, imgsz: int=640):
67
self.device = device
78
self.imgsz = imgsz
89

9-
def __call__(self, images, annotations):
10+
def __call__(self, images, annotations) -> typing.Tuple[np.ndarray, dict]:
1011
batch = {
1112
"ori_shape": [],
1213
"resized_shape": [],
@@ -23,8 +24,8 @@ def __call__(self, images, annotations):
2324
batch["bboxes"].append(detection.xywh)
2425
batch["batch_idx"].append(i)
2526

26-
batch["cls"] = torch.tensor(batch["cls"]).to(self.device)
27-
batch["bboxes"] = torch.tensor(batch["bboxes"]).to(self.device)
28-
batch["batch_idx"] = torch.tensor(batch["batch_idx"]).to(self.device)
27+
batch["cls"] = torch.tensor(np.array(batch["cls"])).to(self.device)
28+
batch["bboxes"] = torch.tensor(np.array(batch["bboxes"])).to(self.device)
29+
batch["batch_idx"] = torch.tensor(np.array(batch["batch_idx"])).to(self.device)
2930

3031
return np.array(images), batch

0 commit comments

Comments
 (0)