Skip to content

Commit 9be6b09

Browse files
tf-model-analysis-teamtfx-copybara
authored andcommitted
Add a _post_process_result to tflite_predict_extractor.
PiperOrigin-RevId: 620862012
1 parent d216096 commit 9be6b09

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tensorflow_model_analysis/extractors/tflite_predict_extractor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def setup(self):
7171
def _make_interpreter(self, **kwargs) -> tf.lite.Interpreter:
7272
return tf.lite.Interpreter(**kwargs)
7373

74+
def _post_process_result(self, input_tensor: np.ndarray) -> np.ndarray:
75+
"""Custom post processor for TFLite predictions, default is no-op."""
76+
return input_tensor
77+
7478
def _get_input_name_from_input_detail(self, input_detail):
7579
"""Get input name from input detail.
7680
@@ -85,7 +89,7 @@ def _get_input_name_from_input_detail(self, input_detail):
8589
# of the input names. TFLite rewriter assumes that the default signature key
8690
# ('serving_default') will be used as an exported name when saving.
8791
if input_name.startswith('serving_default_'):
88-
input_name = input_name[len('serving_default_'):]
92+
input_name = input_name[len('serving_default_') :]
8993
# Remove argument that starts with ':'.
9094
input_name = input_name.split(':')[0]
9195
return input_name
@@ -187,10 +191,12 @@ def _batch_reducible_process(
187191
for o in output_details:
188192
tensor = interpreter.get_tensor(o[_INDEX])
189193
params = o[_QUANTIZATION_PARAMETERS]
190-
outputs[o['name']] = self._dequantize(
194+
dequantized_tensor = self._dequantize(
191195
tensor, params[_SCALES], params[_ZERO_POINTS]
192196
)
193197

198+
outputs[o['name']] = self._post_process_result(dequantized_tensor)
199+
194200
for v in outputs.values():
195201
if len(v) != batch_size:
196202
raise ValueError('Did not get the expected number of results.')

0 commit comments

Comments
 (0)