diff --git a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
index ba08a99f1b6..4e6306c8dda 100644
--- a/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
@@ -244,6 +244,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 #endif  // #if defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
   }
 
+#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+  if (input->type == kTfLiteFloat32) {
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+    const int stride_width = params->stride_width;
+    const int stride_height = params->stride_height;
+
+    const int input_height = SizeOfDimension(input, 1);
+    const int input_width = SizeOfDimension(input, 2);
+    const int input_depth = SizeOfDimension(input, 3);
+    const int output_height = height;
+    const int output_width = width;
+    int32_t scratch_buffer_size = 0;
+    scratch_buffer_size = xa_nn_transpose_conv_getsize(input_height,
+        input_width, input_depth, filter_height,
+        filter_width, stride_width, stride_height,
+        output_height, output_width, num_channels,
+        PREC_F32, PREC_F32);
+    TFLITE_DCHECK(context->RequestScratchBufferInArena(
+                      context,
+                      scratch_buffer_size,
+                      &(data->scratch_buffer_index)) == kTfLiteOk);
+  }
+#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+
   // All per-channel quantized tensors need valid zero point and scale arrays.
   if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
     TF_LITE_ENSURE_EQ(context, filter->quantization.type,
@@ -334,7 +358,58 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       CalculateActivationRange(params.activation,
                                &op_params.float_activation_min,
                                &op_params.float_activation_max);
-
+#if HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
+      std::float_t* scratch_buffer = static_cast<std::float_t*>(
+          context->GetScratchBuffer(context, data.scratch_buffer_index));
+      const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+      const RuntimeShape& filter_shape = tflite::micro::GetTensorShape(filter);
+      const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+      const int stride_width = data.params.stride_width;
+      const int stride_height = data.params.stride_height;
+      const int pad_width = data.params.padding_values.width;
+      const int pad_height = data.params.padding_values.height;
+
+      const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+      const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+      const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+
+      const int input_height = input_shape.Dims(1);
+      const int input_width = input_shape.Dims(2);
+      const int filter_height = filter_shape.Dims(1);
+      const int filter_width = filter_shape.Dims(2);
+      const int output_height = output_shape.Dims(1);
+      const int output_width = output_shape.Dims(2);
+      const float* input_data = tflite::micro::GetTensorData<float>(input);
+#ifdef USE_TFLM_COMPRESSION
+      const float* filter_data = tflite::micro::GetTensorData<float>(
+          micro_context, filter, filter_comp_td, data.filter_scratch_index);
+      const float* bias_data = tflite::micro::GetOptionalTensorData<float>(
+          micro_context, bias, bias_comp_td, data.bias_scratch_index);
+#else
+      const float* filter_data = tflite::micro::GetTensorData<float>(filter);
+      const float* bias_data = tflite::micro::GetTensorData<float>(bias);
+#endif  // USE_TFLM_COMPRESSION
+
+      float* output_data = tflite::micro::GetTensorData<float>(output);
+
+      const int num_elements = output_shape.FlatSize();
+      const int output_elements =
+          batches * output_height * output_width * output_depth;
+
+      for (int b = 0; b < batches; b++) {
+        xa_nn_transpose_conv_f32(
+            &output_data[b * output_height * output_width * output_depth],
+            const_cast<float*>(
+                &input_data[b * input_height * input_width * input_depth]),
+            const_cast<float*>(filter_data), const_cast<float*>(bias_data),
+            stride_width, stride_height, pad_width, pad_height, input_depth,
+            output_depth, input_height, input_width, filter_height, filter_width,
+            output_height, output_width, num_elements / batches, scratch_buffer);
+      }
+      xa_nn_vec_activation_min_max_f32_f32(
+          output_data, output_data, op_params.float_activation_min,
+          op_params.float_activation_max, output_elements);
+#else
       reference_ops::TransposeConv(
           op_params, tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
@@ -353,6 +428,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output),
           tflite::micro::GetTensorShape(nullptr), nullptr);
+#endif  // HIFI_VFPU && (defined(HIFI3) || defined(HIFI4) || defined(HIFI5))
       break;
     }
     case kTfLiteInt8: {
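
Note on the per-batch indexing in the Eval hunk above: xa_nn_transpose_conv_f32 is called once per batch, so the input and output pointers are advanced by one batch's worth of NHWC elements and the element count passed in is num_elements / batches. The following standalone C++ sketch (illustration only, not part of the change; the shape values are made up) shows that offset arithmetic:

#include <cassert>

int main() {
  // Hypothetical NHWC shapes, for illustration only.
  const int batches = 2;
  const int input_height = 8, input_width = 8, input_depth = 3;
  const int output_height = 16, output_width = 16, output_depth = 4;

  // Equivalent of output_shape.FlatSize() for an NHWC output tensor.
  const int num_elements =
      batches * output_height * output_width * output_depth;

  for (int b = 0; b < batches; b++) {
    // Start of batch b in the flattened NHWC buffers, matching the
    // &input_data[...] and &output_data[...] offsets in the kernel loop.
    const int input_offset = b * input_height * input_width * input_depth;
    const int output_offset = b * output_height * output_width * output_depth;

    // Per-call element count handed to the kernel: one batch of output.
    assert(num_elements / batches ==
           output_height * output_width * output_depth);
    (void)input_offset;
    (void)output_offset;
  }
  return 0;
}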