Skip to content

Commit cfc8781

Browse files
committed
video and readme with imgs
1 parent 8a7cd85 commit cfc8781

File tree

5 files changed

+49
-11
lines changed

5 files changed

+49
-11
lines changed

README.md

+37-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Yet-Another-YOLOv4-Pytorch
2+
![](github_imgs/from_net.png)
23

34
This is implementation of YOLOv4 object detection neural network on pytorch. I'll try to implement all features of original paper.
45

@@ -14,12 +15,26 @@ This is implementation of YOLOv4 object detection neural network on pytorch. I'l
1415
- [ ] Self attention attack
1516
- [ ] Notebook with guide
1617

18+
## What you can already do
19+
You can use video_demo.py to see real-time object detection with the original weights. (I get 9 fps on my GTX1060 laptop!!!)
20+
![](/github_imgs/realtime.jpg)
21+
22+
You can train your own model using mosaic augmentation. Guides on how to do this are written below. On some datasets, the borders between the mosaic images are even hard to spot.
23+
![](/github_imgs/mosaic.png)
24+
25+
26+
You can run inference; see the guide below.
27+
28+
1729
## Initialize NN
1830

1931
import model
2032
#If you change n_classes from the pretrained, there will be caught one error, don't panic it is ok
2133
m = model.YOLOv4(n_classes=1, weights_path="weights/yolov4.pth")
2234

35+
## Download weights
36+
You can download the weights from this link: https://drive.google.com/open?id=12AaR4fvIQPZ468vhm0ZYZSLgWac2HBnq
37+
2338
## Initialize dataset
2439

2540
import dataset
@@ -43,22 +58,35 @@ dataset has collate_function
4358
paths_b, xb, yb = d.collate_fn((y1, y2))
4459
# yb has 6 columns
4560

46-
### Bboxes format
61+
## Y's format
4762
1. Num of img to which this anchor belongs
4863
2. BBox class
4964
3. x center
5065
4. y center
5166
5. width
5267
6. height
53-
54-
### Forward with loss
55-
(y_hat1, y_hat2, y_hat3), (loss_1, loss_2, loss_3) = m(xb, yb)
5668

57-
### Forward without loss
58-
(y_hat1, y_hat2, y_hat3), _ = m(img_batch) #_ is (0, 0, 0)
69+
## Forward with loss
70+
y_hat, loss = m(xb, yb)
71+
72+
!!! y_hat is already resized anchors to image size bboxes
5973

60-
### Check if bboxes are correct
74+
## Forward without loss
75+
y_hat, _ = m(img_batch) #_ is (0, 0, 0)
76+
77+
## Check if bboxes are correct
6178
import utils
79+
from PIL import Image
6280
path, img, bboxes = d[0]
63-
img_with_bboxes = utils.get_img_with_bboxes(img, bboxes) #PIL image
64-
81+
img_with_bboxes = utils.get_img_with_bboxes(img, bboxes[:, 2:]) #Returns numpy array
82+
Image.fromarray(img_with_bboxes)
83+
84+
## Get predicted bboxes
85+
anchors, loss = m(xb.cuda(), yb.cuda())
86+
confidence_threshold = 0.05
87+
iou_threshold = 0.5
88+
bboxes, labels = utils.get_bboxes_from_anchors(anchors, confidence_threshold, iou_threshold, coco_dict) #COCO dict is id->class dictionary (f.e. 0->person)
89+
#For first img
90+
arr = utils.get_img_with_bboxes(xb[0].cpu(), bboxes[0].cpu(), resize=False, labels=labels[0])
91+
Image.fromarray(arr)
92+

github_imgs/from_net.png

498 KB
Loading

github_imgs/mosaic.png

669 KB
Loading

github_imgs/realtime.jpg

58.7 KB
Loading

video_demo.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch.backends import cudnn
44
import torch
55
import utils
6+
import time
67

78
coco_dict = {0: 'person',
89
1: 'bicycle',
@@ -97,11 +98,15 @@
9798

9899
m = m.cuda()
99100

101+
#To warm up JIT
102+
m(torch.zeros((1, 3, 608, 608)).cuda())
100103

101104
cap = cv2.VideoCapture(0)
102105

106+
frames_n = 0
107+
start_time = time.time()
108+
103109
while True:
104-
print("Frame got")
105110
ret, frame = cap.read()
106111
if not ret:
107112
break
@@ -123,14 +128,19 @@
123128

124129
bboxes, labels = utils.get_bboxes_from_anchors(anchors, 0.4, 0.5, coco_dict)
125130
arr = utils.get_img_with_bboxes(x.cpu(), bboxes[0].cpu(), resize=False, labels=labels[0])
126-
127131
arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
132+
133+
frames_n += 1
134+
135+
arr = cv2.putText(arr, "FPS: " + str(frames_n / (time.time() - start_time)), (100, 100), cv2.FONT_HERSHEY_DUPLEX, 0.75, (255, 255, 255))
128136

129137
cv2.imshow("test", arr)
130138
if cv2.waitKey(1) & 0xFF == ord('q'):
131139
break
132140

133141

142+
143+
134144

135145

136146

0 commit comments

Comments
 (0)