diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
index 73da17baedf..a6affaa11bb 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
@@ -94,6 +94,46 @@ TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_maximum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Maximum Int8 registration.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
   const TfLiteEvalTensor* input1 =
@@ -146,6 +186,46 @@ TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_minimum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Minimum Int8 registration.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
 }  // namespace
 
 TFLMRegistration Register_MAXIMUM() {
@@ -156,4 +236,12 @@ TFLMRegistration Register_MINIMUM() {
   return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimum);
 }
 
+TFLMRegistration Register_MAXIMUM_INT8() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMaximumInt8);
+}
+
+TFLMRegistration Register_MINIMUM_INT8() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimumInt8);
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/maximum_minimum.h b/tensorflow/lite/micro/kernels/maximum_minimum.h
index ac497fe51ae..34d7e2399f3 100644
--- a/tensorflow/lite/micro/kernels/maximum_minimum.h
+++ b/tensorflow/lite/micro/kernels/maximum_minimum.h
@@ -80,6 +80,26 @@ TFLMRegistration Register_MAXIMUM();
 
 TFLMRegistration Register_MINIMUM();
 
+#if defined(CMSIS_NN)
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8.
+TFLMRegistration Register_MAXIMUM_INT8();
+
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8.
+TFLMRegistration Register_MINIMUM_INT8();
+
+#else
+// Note that while this block gets used for both reference and optimized kernels
+// that do not have any specialized implementations, the only goal here is to
+// define fallback implementations that allow reference kernels to still be used
+// from applications that call a more specific kernel variant.
+inline TFLMRegistration Register_MAXIMUM_INT8() { return Register_MAXIMUM(); }
+
+inline TFLMRegistration Register_MINIMUM_INT8() { return Register_MINIMUM(); }
+
+#endif
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_MAXIMUM_MINIMUM_H_
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index f5f6e38e003..ad642ddbc06 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/depthwise_conv.h" #include "tensorflow/lite/micro/kernels/ethosu.h" #include "tensorflow/lite/micro/kernels/fully_connected.h" +#include "tensorflow/lite/micro/kernels/maximum_minimum.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/kernels/mul.h" #include "tensorflow/lite/micro/kernels/pooling.h" @@ -414,9 +415,9 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::Register_LOG_SOFTMAX(), ParseLogSoftmax); } - TfLiteStatus AddMaximum() { - return AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(), - ParseMaximum); + TfLiteStatus AddMaximum( + const TFLMRegistration& registration = Register_MAXIMUM()) { + return AddBuiltin(BuiltinOperator_MAXIMUM, registration, ParseMaximum); } TfLiteStatus AddMaxPool2D( @@ -433,9 +434,9 @@ class MicroMutableOpResolver : public MicroOpResolver { return AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), ParseReducer); } - TfLiteStatus AddMinimum() { - return AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(), - ParseMinimum); + TfLiteStatus AddMinimum( + const TFLMRegistration& registration = Register_MINIMUM()) { + return AddBuiltin(BuiltinOperator_MINIMUM, registration, ParseMinimum); } TfLiteStatus AddMul(const TFLMRegistration& registration = Register_MUL()) { @@ -452,7 +453,8 @@ class MicroMutableOpResolver : public MicroOpResolver { } TfLiteStatus AddOverlapAdd() { - // TODO(b/286250473): change back name to "OverlapAdd" and remove namespace + // TODO(b/286250473): change back name to "OverlapAdd" and remove + // namespace return AddCustom("SignalOverlapAdd", tflite::tflm_signal::Register_OVERLAP_ADD()); } @@ -684,8 +686,8 @@ class MicroMutableOpResolver : public MicroOpResolver { } registrations_[registrations_len_] = registration; - // Strictly speaking, the builtin_code is not necessary for TFLM but filling - // it in regardless. + // Strictly speaking, the builtin_code is not necessary for TFLM but + // filling it in regardless. registrations_[registrations_len_].builtin_code = op; registrations_len_++; diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh index 26c6487f5f4..a211a2b38a3 100755 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh @@ -38,9 +38,9 @@ source ${TENSORFLOW_ROOT}tensorflow/lite/micro/tools/make/bash_helpers.sh DOWNLOADS_DIR=${1} DOWNLOADED_CMSIS_NN_PATH=${DOWNLOADS_DIR}/cmsis_nn -ZIP_PREFIX_NN="5f8f1a96797cfce64032492151b01cf0e1c97f06" +ZIP_PREFIX_NN="22080c68d040c98139e6cb1549473e3149735f4d" CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip" -CMSIS_NN_MD5="903bbdaf3b73ed3c5e42e46b9d8f1f7e" +CMSIS_NN_MD5="32aa69692541060a76b18bd5d2d98956" should_download=$(check_should_download ${DOWNLOADS_DIR})