Skip to content

Commit 21157b6

Browse files
Update to work with TF1 or TF2 models
1 parent 00b008c commit 21157b6

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

TFLite_detection_image.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@
129129
input_mean = 127.5
130130
input_std = 127.5
131131

132+
# Check output layer name to determine if this model was created with TF2 or TF1,
133+
# because outputs are ordered differently for TF2 and TF1 models
134+
outname = output_details[0]['name']
135+
136+
if ('StatefulPartitionedCall' in outname): # This is a TF2 model
137+
boxes_idx, classes_idx, scores_idx = 1, 3, 0
138+
else: # This is a TF1 model
139+
boxes_idx, classes_idx, scores_idx = 0, 1, 2
140+
132141
# Loop over every image and perform detection
133142
for image_path in images:
134143

@@ -148,10 +157,9 @@
148157
interpreter.invoke()
149158

150159
# Retrieve detection results
151-
boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
152-
classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
153-
scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
154-
#num = interpreter.get_tensor(output_details[3]['index'])[0] # Total number of detected objects (inaccurate and not needed)
160+
boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0] # Bounding box coordinates of detected objects
161+
classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0] # Class index of detected objects
162+
scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0] # Confidence of detected objects
155163

156164
# Loop over all detections and draw detection box if confidence is above minimum threshold
157165
for i in range(len(scores)):

0 commit comments

Comments
 (0)