@@ -71,6 +71,10 @@ def setup(self):
7171 def _make_interpreter (self , ** kwargs ) -> tf .lite .Interpreter :
7272 return tf .lite .Interpreter (** kwargs )
7373
74+ def _post_process_result (self , input_tensor : np .ndarray ) -> np .ndarray :
75+ """Custom post processor for TFLite predictions, default is no-op."""
76+ return input_tensor
77+
7478 def _get_input_name_from_input_detail (self , input_detail ):
7579 """Get input name from input detail.
7680
@@ -85,7 +89,7 @@ def _get_input_name_from_input_detail(self, input_detail):
8589 # of the input names. TFLite rewriter assumes that the default signature key
8690 # ('serving_default') will be used as an exported name when saving.
8791 if input_name .startswith ('serving_default_' ):
88- input_name = input_name [len ('serving_default_' ):]
92+ input_name = input_name [len ('serving_default_' ) :]
8993 # Remove argument that starts with ':'.
9094 input_name = input_name .split (':' )[0 ]
9195 return input_name
@@ -187,10 +191,12 @@ def _batch_reducible_process(
187191 for o in output_details :
188192 tensor = interpreter .get_tensor (o [_INDEX ])
189193 params = o [_QUANTIZATION_PARAMETERS ]
190- outputs [ o [ 'name' ]] = self ._dequantize (
194+ dequantized_tensor = self ._dequantize (
191195 tensor , params [_SCALES ], params [_ZERO_POINTS ]
192196 )
193197
198+ outputs [o ['name' ]] = self ._post_process_result (dequantized_tensor )
199+
194200 for v in outputs .values ():
195201 if len (v ) != batch_size :
196202 raise ValueError ('Did not get the expected number of results.' )
0 commit comments