Skip to content

Commit 5a22e28

Browse files
Update to work with TF1 or TF2 models
1 parent 60f3deb commit 5a22e28

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

TFLite_detection_webcam.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ def stop(self):
150150
input_mean = 127.5
151151
input_std = 127.5
152152

153+
# Check output layer name to determine if this model was created with TF2 or TF1,
154+
# because outputs are ordered differently for TF2 and TF1 models
155+
outname = output_details[0]['name']
156+
157+
if ('StatefulPartitionedCall' in outname): # This is a TF2 model
158+
boxes_idx, classes_idx, scores_idx = 1, 3, 0
159+
else: # This is a TF1 model
160+
boxes_idx, classes_idx, scores_idx = 0, 1, 2
161+
153162
# Initialize frame rate calculation
154163
frame_rate_calc = 1
155164
freq = cv2.getTickFrequency()
@@ -182,10 +191,9 @@ def stop(self):
182191
interpreter.invoke()
183192

184193
# Retrieve detection results
185-
boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
186-
classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
187-
scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
188-
#num = interpreter.get_tensor(output_details[3]['index'])[0] # Total number of detected objects (inaccurate and not needed)
194+
boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0] # Bounding box coordinates of detected objects
195+
classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0] # Class index of detected objects
196+
scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0] # Confidence of detected objects
189197

190198
# Loop over all detections and draw detection box if confidence is above minimum threshold
191199
for i in range(len(scores)):

0 commit comments

Comments
 (0)