
Commit

Update cmsis min/max with int8 registration
Signed-off-by: Ryan O'Shea <[email protected]>
ArmRyan committed Nov 25, 2024
1 parent 4377f8c commit 81d16ea
Showing 4 changed files with 122 additions and 12 deletions.
88 changes: 88 additions & 0 deletions tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
@@ -94,6 +94,46 @@ TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());

switch (op_context.output->type) {
case kTfLiteInt8:
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;

arm_maximum_s8(
&ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
tflite::micro::GetTensorData<int8_t>(output), &output_dims);
break;
default:
MicroPrintf("Type %s (%d) is not supported by Maximum Int8 Registration.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
return kTfLiteError;
}
return kTfLiteOk;
}

TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
const TfLiteEvalTensor* input1 =
@@ -146,6 +186,46 @@ TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());

switch (op_context.output->type) {
case kTfLiteInt8:
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;

arm_minimum_s8(
&ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
tflite::micro::GetTensorData<int8_t>(output), &output_dims);
break;
default:
MicroPrintf("Type %s (%d) is not supported by Minimum Int8 registration.",
TfLiteTypeGetName(op_context.output->type),
op_context.output->type);
return kTfLiteError;
}
return kTfLiteOk;
}

} // namespace

TFLMRegistration Register_MAXIMUM() {
@@ -156,4 +236,12 @@ TFLMRegistration Register_MINIMUM() {
return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimum);
}

TFLMRegistration Register_MAXIMUM_INT8() {
return tflite::micro::RegisterOp(nullptr, nullptr, EvalMaximumInt8);
}

TFLMRegistration Register_MINIMUM_INT8() {
return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimumInt8);
}

} // namespace tflite
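Both new eval functions call FillVariableShape(), which is defined earlier in maximum_minimum.cc and falls outside the hunks shown above. For orientation only, a minimal sketch of a helper with that signature follows; the right-alignment of dimensions into CMSIS-NN's {n, h, w, c} layout is an assumption here, not the code from this commit.

// Sketch only -- the real FillVariableShape lives earlier in
// maximum_minimum.cc and may differ. Assumes cmsis_nn_dims from the
// CMSIS-NN headers that file already includes ({n, h, w, c} layout).
cmsis_nn_dims FillVariableShape(int32_t rank, int32_t* tensor_dims) {
  int32_t padded[4] = {1, 1, 1, 1};
  // Right-align the trailing dimensions so lower-rank tensors are padded
  // with leading 1s, e.g. a rank-2 {rows, cols} becomes {1, 1, rows, cols}.
  for (int32_t i = 0; i < 4 && i < rank; ++i) {
    padded[3 - i] = tensor_dims[rank - 1 - i];
  }
  cmsis_nn_dims dims;
  dims.n = padded[0];
  dims.h = padded[1];
  dims.w = padded[2];
  dims.c = padded[3];
  return dims;
}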
20 changes: 20 additions & 0 deletions tensorflow/lite/micro/kernels/maximum_minimum.h
@@ -80,6 +80,26 @@ TFLMRegistration Register_MAXIMUM();

TFLMRegistration Register_MINIMUM();

#if defined(CMSIS_NN)
// Returns a TFLMRegistration struct for a kernel variant that only supports
// int8.
TFLMRegistration Register_MAXIMUM_INT8();

// Returns a TFLMRegistration struct for a kernel variant that only supports
// int8.
TFLMRegistration Register_MINIMUM_INT8();

#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementations that allow reference kernels to still be
// used from applications that call a more specific kernel variant.
inline TFLMRegistration Register_MAXIMUM_INT8() { return Register_MAXIMUM(); }

inline TFLMRegistration Register_MINIMUM_INT8() { return Register_MINIMUM(); }

#endif

} // namespace tflite

#endif // TENSORFLOW_LITE_MICRO_KERNELS_MAXIMUM_MINIMUM_H_
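Because Register_MAXIMUM_INT8()/Register_MINIMUM_INT8() are declared unconditionally, application code can request the int8-only variants without its own #ifdef; on builds without CMSIS_NN the inline fallbacks above simply return the reference registrations. A minimal illustration, not part of this commit:

#include "tensorflow/lite/micro/kernels/maximum_minimum.h"

// Illustrative only: the same call compiles on every build. With CMSIS_NN
// defined it returns the int8-specialized registration added in this
// commit; otherwise the inline fallback above forwards to the reference
// Register_MAXIMUM().
inline auto GetInt8MaximumRegistration() {
  return tflite::Register_MAXIMUM_INT8();
}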
22 changes: 12 additions & 10 deletions tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/ethosu.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/kernels/mul.h"
#include "tensorflow/lite/micro/kernels/pooling.h"
@@ -414,9 +415,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::Register_LOG_SOFTMAX(), ParseLogSoftmax);
}

-  TfLiteStatus AddMaximum() {
-    return AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
-                      ParseMaximum);
+  TfLiteStatus AddMaximum(
+      const TFLMRegistration& registration = Register_MAXIMUM()) {
+    return AddBuiltin(BuiltinOperator_MAXIMUM, registration, ParseMaximum);
}

TfLiteStatus AddMaxPool2D(
@@ -433,9 +434,9 @@
return AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), ParseReducer);
}

-  TfLiteStatus AddMinimum() {
-    return AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(),
-                      ParseMinimum);
+  TfLiteStatus AddMinimum(
+      const TFLMRegistration& registration = Register_MINIMUM()) {
+    return AddBuiltin(BuiltinOperator_MINIMUM, registration, ParseMinimum);
}

TfLiteStatus AddMul(const TFLMRegistration& registration = Register_MUL()) {
@@ -452,7 +453,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddOverlapAdd() {
-    // TODO(b/286250473): change back name to "OverlapAdd" and remove namespace
+    // TODO(b/286250473): change back name to "OverlapAdd" and remove
+    // namespace
return AddCustom("SignalOverlapAdd",
tflite::tflm_signal::Register_OVERLAP_ADD());
}
@@ -684,8 +686,8 @@
}

registrations_[registrations_len_] = registration;
-    // Strictly speaking, the builtin_code is not necessary for TFLM but filling
-    // it in regardless.
+    // Strictly speaking, the builtin_code is not necessary for TFLM but
+    // filling it in regardless.
registrations_[registrations_len_].builtin_code = op;
registrations_len_++;

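With the new default argument, existing AddMaximum()/AddMinimum() callers are unchanged, while callers that want the int8-only CMSIS-NN kernels can pass the specialized registration explicitly. A hedged sketch of such a caller (the resolver size and function name are illustrative, not part of this commit):

#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Illustrative sketch: register MAXIMUM with the default (general)
// registration and MINIMUM with the int8-only variant.
TfLiteStatus RegisterMinMaxOps(tflite::MicroMutableOpResolver<2>& resolver) {
  // Default argument: behaves exactly as before this change, i.e.
  // Register_MAXIMUM().
  if (resolver.AddMaximum() != kTfLiteOk) {
    return kTfLiteError;
  }
  // Explicit registration: int8-only kernel on CMSIS-NN builds; falls back
  // to the reference registration otherwise.
  return resolver.AddMinimum(tflite::Register_MINIMUM_INT8());
}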
@@ -38,9 +38,9 @@ source ${TENSORFLOW_ROOT}tensorflow/lite/micro/tools/make/bash_helpers.sh
DOWNLOADS_DIR=${1}
DOWNLOADED_CMSIS_NN_PATH=${DOWNLOADS_DIR}/cmsis_nn

ZIP_PREFIX_NN="5f8f1a96797cfce64032492151b01cf0e1c97f06"
ZIP_PREFIX_NN="22080c68d040c98139e6cb1549473e3149735f4d"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
CMSIS_NN_MD5="903bbdaf3b73ed3c5e42e46b9d8f1f7e"
CMSIS_NN_MD5="32aa69692541060a76b18bd5d2d98956"

should_download=$(check_should_download ${DOWNLOADS_DIR})

