aiy_cat_detection.py
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API for Object Detection tasks.
Modified by chadwallacehart
"""
import math
import sys
import os
from aiy.vision.inference import ModelDescriptor
from aiy.vision.models import utils
from aiy.vision.models.object_detection_anchors import ANCHORS

# _COMPUTE_GRAPH_NAME = 'mobilenet_ssd_256res_0.125_person_cat_dog.binaryproto'
# Note: the path is resolved against the current working directory, so run the
# script from the directory that contains custom_models/.
_COMPUTE_GRAPH_NAME = os.path.join(os.getcwd(), 'custom_models/cat_detector_cwh_180602.binaryproto')
_NUM_ANCHORS = len(ANCHORS)
_MACHINE_EPS = sys.float_info.epsilon
_NUM_LABELS = 3  # chadwallacehart: changed to match the custom model's three classes


class Object(object):
    """Object detection result."""

    # chadwallacehart: updated for my custom model below
    BACKGROUND = 0
    CAT = 1
    PERSON = 2

    _LABELS = {
        BACKGROUND: 'BACKGROUND',
        CAT: 'CAT',
        PERSON: 'PERSON',
    }

    def __init__(self, bounding_box, kind, score):
        """Initialization.

        Args:
          bounding_box: a tuple of 4 ints, (x, y, width, height) order.
          kind: int, tells what object is in the bounding box.
          score: float, confidence score.
        """
        self.bounding_box = bounding_box
        self.kind = kind
        self.score = score
        self.label = self._LABELS[self.kind]  # chadwallacehart: added

    def __str__(self):
        return 'kind=%s(%d), score=%f, bbox=%s' % (self._LABELS[self.kind],
                                                   self.kind, self.score,
                                                   str(self.bounding_box))
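
# For example (illustrative values only, not from the original file):
#   obj = Object((10, 20, 120, 140), Object.CAT, 0.87)
#   str(obj)  # -> "kind=CAT(1), score=0.870000, bbox=(10, 20, 120, 140)"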


def _decode_detection_result(logit_scores, box_encodings, anchors,
                             score_threshold, image_size, offset):
    """Decodes result as bounding boxes.

    Args:
      logit_scores: list of scores
      box_encodings: list of bounding boxes
      anchors: list of anchors
      score_threshold: float, bounding box candidates below this threshold will
        be rejected.
      image_size: (width, height)
      offset: (x, y)

    Returns:
      A list of Object.
    """
    assert len(box_encodings) == 4 * _NUM_ANCHORS
    # chadwallacehart: modified below to handle a variable number of labels
    assert len(logit_scores) == _NUM_LABELS * _NUM_ANCHORS

    x0, y0 = offset
    width, height = image_size
    objs = []

    # Compare scores in logit space: log(p / (1 - p)) is the inverse sigmoid,
    # so the sigmoid only has to be applied to boxes that pass the threshold.
    score_threshold = max(score_threshold, _MACHINE_EPS)
    logit_score_threshold = math.log(score_threshold / (1 - score_threshold))
    for i in range(_NUM_ANCHORS):
        # chadwallacehart: modified below to handle a variable number of labels
        logits = logit_scores[_NUM_LABELS * i: _NUM_LABELS * (i + 1)]
        max_logit_score = max(logits)
        max_score_index = logits.index(max_logit_score)
        # Skip if max score is below threshold or max score is 'background'.
        if max_score_index == 0 or max_logit_score <= logit_score_threshold:
            continue
        box_encoding = box_encodings[4 * i: 4 * (i + 1)]
        xmin, ymin, xmax, ymax = _decode_box_encoding(box_encoding, anchors[i])
        x = int(x0 + xmin * width)
        y = int(y0 + ymin * height)
        w = int((xmax - xmin) * width)
        h = int((ymax - ymin) * height)
        max_score = 1.0 / (1.0 + math.exp(-max_logit_score))
        objs.append(Object((x, y, w, h), max_score_index, max_score))
    return objs


def _clamp(value):
    """Clamps value to range [0.0, 1.0]."""
    return min(max(0.0, value), 1.0)


def _decode_box_encoding(box_encoding, anchor):
    """Decodes bounding box encoding.

    Args:
      box_encoding: a tuple of 4 floats.
      anchor: a tuple of 4 floats.

    Returns:
      A tuple of 4 floats (xmin, ymin, xmax, ymax), each has range [0.0, 1.0].
    """
    assert len(box_encoding) == 4
    assert len(anchor) == 4
    y_scale = 10.0
    x_scale = 10.0
    height_scale = 5.0
    width_scale = 5.0
    rel_y_translation = box_encoding[0] / y_scale
    rel_x_translation = box_encoding[1] / x_scale
    rel_height_dilation = box_encoding[2] / height_scale
    rel_width_dilation = box_encoding[3] / width_scale

    anchor_ymin, anchor_xmin, anchor_ymax, anchor_xmax = anchor
    anchor_ycenter = (anchor_ymax + anchor_ymin) / 2
    anchor_xcenter = (anchor_xmax + anchor_xmin) / 2
    anchor_height = anchor_ymax - anchor_ymin
    anchor_width = anchor_xmax - anchor_xmin

    ycenter = anchor_ycenter + anchor_height * rel_y_translation
    xcenter = anchor_xcenter + anchor_width * rel_x_translation
    height = math.exp(rel_height_dilation) * anchor_height
    width = math.exp(rel_width_dilation) * anchor_width

    # Clamp value to [0.0, 1.0] range, otherwise, part of the bounding box may
    # fall outside of the image.
    xmin = _clamp(xcenter - width / 2)
    ymin = _clamp(ycenter - height / 2)
    xmax = _clamp(xcenter + width / 2)
    ymax = _clamp(ycenter + height / 2)
    return (xmin, ymin, xmax, ymax)
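
# A quick sanity check of the decoding above (hypothetical values, not from the
# original file): an all-zero encoding reproduces its anchor, since the
# translations vanish and exp(0) == 1 leaves the size unchanged.
#
#   _decode_box_encoding((0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.5, 0.5))
#   # -> (0.0, 0.0, 0.5, 0.5)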


def _area(box):
    _, _, width, height = box
    area = width * height
    assert area >= 0
    return area


def _intersection_area(box1, box2):
    x1, y1, width1, height1 = box1
    x2, y2, width2, height2 = box2
    x = max(x1, x2)
    y = max(y1, y2)
    width = max(min(x1 + width1, x2 + width2) - x, 0)
    height = max(min(y1 + height1, y2 + height2) - y, 0)
    area = width * height
    assert area >= 0
    return area


def _overlap_ratio(box1, box2):
    """Computes overlap ratio (intersection over union) of two bounding boxes.

    Args:
      box1: (x, y, width, height).
      box2: (x, y, width, height).

    Returns:
      float, the intersection-over-union of the two boxes, in range [0.0, 1.0].
    """
    intersection_area = _intersection_area(box1, box2)
    union_area = _area(box1) + _area(box2) - intersection_area
    assert union_area >= 0
    if union_area > 0:
        return float(intersection_area) / float(union_area)
    return 1.0
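
# Worked example (hypothetical boxes, not from the original file): the 2x2
# boxes (0, 0, 2, 2) and (1, 1, 2, 2) intersect in a 1x1 region, so
# _overlap_ratio returns 1 / (4 + 4 - 1) ~= 0.143.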


def _non_maximum_suppression(objs, overlap_threshold=0.5):
    """Runs Non-Maximum Suppression.

    Removes any candidate that overlaps a higher-scoring candidate by more
    than overlap_threshold.

    Args:
      objs: list of Object.
      overlap_threshold: float.

    Returns:
      A list of Object.
    """
    objs = sorted(objs, key=lambda x: x.score, reverse=True)
    for i in range(len(objs)):
        if objs[i].score < 0.0:
            continue
        # Suppress any nearby bounding boxes having lower score than objs[i]
        for j in range(i + 1, len(objs)):
            if objs[j].score < 0.0:
                continue
            if _overlap_ratio(objs[i].bounding_box,
                              objs[j].bounding_box) > overlap_threshold:
                objs[j].score = -1.0  # Suppress box
    return [obj for obj in objs if obj.score >= 0.0]  # Exclude suppressed boxes


def model():
    return ModelDescriptor(
        name='object_detection',
        input_shape=(1, 256, 256, 3),
        # (mean, stddev): maps uint8 pixels to roughly [-1, 1] via (p - 128) / 128.
        input_normalizer=(128.0, 128.0),
        compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME))


# TODO: check all tensor shapes
def get_objects(result, score_threshold=0.3, offset=(0, 0)):
    assert len(result.tensors) == 2
    logit_scores = tuple(result.tensors['concat_1'].data)
    box_encodings = tuple(result.tensors['concat'].data)
    size = (result.window.width, result.window.height)
    objs = _decode_detection_result(logit_scores, box_encodings, ANCHORS,
                                    score_threshold, size, offset)
    return _non_maximum_suppression(objs)
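

# Example usage (a minimal sketch, not part of the original module): run the
# custom detector on the AIY Vision Bonnet with CameraInference. The sensor
# mode and resolution below are assumptions borrowed from the stock AIY demos,
# not values this module prescribes.
if __name__ == '__main__':
    from picamera import PiCamera
    from aiy.vision.inference import CameraInference

    with PiCamera(sensor_mode=4, resolution=(1640, 1232)) as camera:
        camera.start_preview()
        with CameraInference(model()) as inference:
            for result in inference.run():
                for obj in get_objects(result, score_threshold=0.3):
                    print(obj)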