Skip to content

Commit 40d5751

Browse files
Update to work with TF1 or TF2 models
1 parent 21157b6 commit 40d5751

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

TFLite_detection_stream.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def stop(self):
153153
input_mean = 127.5
154154
input_std = 127.5
155155

156+
# Check output layer name to determine if this model was created with TF2 or TF1,
157+
# because outputs are ordered differently for TF2 and TF1 models
158+
outname = output_details[0]['name']
159+
160+
if ('StatefulPartitionedCall' in outname): # This is a TF2 model
161+
boxes_idx, classes_idx, scores_idx = 1, 3, 0
162+
else: # This is a TF1 model
163+
boxes_idx, classes_idx, scores_idx = 0, 1, 2
164+
156165
# Initialize frame rate calculation
157166
frame_rate_calc = 1
158167
freq = cv2.getTickFrequency()
@@ -185,10 +194,9 @@ def stop(self):
185194
interpreter.invoke()
186195

187196
# Retrieve detection results
188-
boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
189-
classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
190-
scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
191-
#num = interpreter.get_tensor(output_details[3]['index'])[0] # Total number of detected objects (inaccurate and not needed)
197+
boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0] # Bounding box coordinates of detected objects
198+
classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0] # Class index of detected objects
199+
scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0] # Confidence of detected objects
192200

193201
# Loop over all detections and draw detection box if confidence is above minimum threshold
194202
for i in range(len(scores)):

0 commit comments

Comments
 (0)