Skip to content

Commit

Permalink
review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Aug 14, 2023
1 parent 095510e commit de92db3
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 270 deletions.
106 changes: 46 additions & 60 deletions tensorflow/lite/micro/kernels/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ limitations under the License.
namespace tflite {
namespace {

constexpr int kInputLHSTensor = 0;
constexpr int kInputRHSTensor = 1;
constexpr int kInputLhsTensor = 0;
constexpr int kInputRhsTensor = 1;
constexpr int kOutputTensor = 0;

struct QuantizationOpData {
Expand Down Expand Up @@ -62,23 +62,21 @@ struct OpData {
};

struct OpContext {
OpContext(TfLiteContext* context, TfLiteNode* node) {
params = static_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
opdata = static_cast<OpData*>(node->user_data);
}
OpContext(TfLiteContext* context, TfLiteNode* node)
: params(static_cast<TfLiteBatchMatMulParams*>(node->builtin_data)),
op_data(static_cast<OpData*>(node->user_data)) {}

TfLiteBatchMatMulParams* params;
OpData* opdata;
OpData* op_data;
};

struct PrepareOpContext : OpContext {
PrepareOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node) {
micro_context_ = GetMicroContext(context);
lhs = micro_context_->AllocateTempInputTensor(node, kInputLHSTensor);
rhs = micro_context_->AllocateTempInputTensor(node, kInputRHSTensor);
output = micro_context_->AllocateTempOutputTensor(node, kOutputTensor);
}
: OpContext(context, node),
micro_context_(GetMicroContext(context)),
lhs(micro_context_->AllocateTempInputTensor(node, kInputLhsTensor)),
rhs(micro_context_->AllocateTempInputTensor(node, kInputRhsTensor)),
output(micro_context_->AllocateTempOutputTensor(node, kOutputTensor)) {}

~PrepareOpContext() {
if (lhs != nullptr) {
Expand All @@ -92,21 +90,21 @@ struct PrepareOpContext : OpContext {
}
}

private:
MicroContext* micro_context_;

public:
TfLiteTensor* lhs;
TfLiteTensor* rhs;
TfLiteTensor* output;

private:
MicroContext* micro_context_;
};

struct EvalOpContext : OpContext {
EvalOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node) {
lhs = tflite::micro::GetEvalInput(context, node, kInputLHSTensor);
rhs = tflite::micro::GetEvalInput(context, node, kInputRHSTensor);
output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);
}
: OpContext(context, node),
lhs(tflite::micro::GetEvalInput(context, node, kInputLhsTensor)),
rhs(tflite::micro::GetEvalInput(context, node, kInputRhsTensor)),
output(tflite::micro::GetEvalOutput(context, node, kOutputTensor)) {}

const TfLiteEvalTensor* lhs;
const TfLiteEvalTensor* rhs;
Expand Down Expand Up @@ -196,27 +194,27 @@ TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor(
// Allocate normal quantization data if needed.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
const PrepareOpContext& op_context) {
OpData* op_data = op_context.opdata;
OpData* op_data = op_context.op_data;
const TfLiteTensor* lhs = op_context.lhs;
const TfLiteTensor* rhs = op_context.rhs;
MicroContext* micro_context = GetMicroContext(context);

op_data->quantization = nullptr;
op_data->lhs_transposed_tensor = nullptr;
op_data->rhs_transposed_tensor = nullptr;

if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) {
op_data->quantization = static_cast<decltype(op_data->quantization)>(
micro_context->AllocatePersistentBuffer(
sizeof(*op_data->quantization)));
TF_LITE_ENSURE(context, op_data->quantization != nullptr);
} else {
op_data->quantization = nullptr;
}

// tensor for Transposed LHS;
if (op_context.params->adj_x) {
op_data->lhs_transposed_tensor =
AllocInitTransposeTensorFromTfLiteTensor(context, *lhs);
TF_LITE_ENSURE(context, op_data->lhs_transposed_tensor != nullptr);
} else {
op_data->lhs_transposed_tensor = nullptr;
}

// We need a buffer for the RHS if we need to transpose the RHS. We
Expand All @@ -227,17 +225,16 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
op_data->rhs_transposed_tensor =
AllocInitTransposeTensorFromTfLiteTensor(context, *rhs);
TF_LITE_ENSURE(context, op_data->rhs_transposed_tensor != nullptr);
} else {
op_data->rhs_transposed_tensor = nullptr;
}

return kTfLiteOk;
}

template <typename scalar>
template <typename Scalar>
void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
const scalar* input, TfLiteEvalTensor* tensor_out,
scalar* output) {
TfLiteEvalTensor* tensor_out) {
const Scalar* input = tflite::micro::GetTensorData<Scalar>(&tensor_in);
Scalar* output = tflite::micro::GetTensorData<Scalar>(tensor_out);
RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
RuntimeShape shape(transposed_shape);
TransposeParams params;
Expand All @@ -254,23 +251,16 @@ void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
reference_ops::Transpose(params, shape, input, transposed_shape, output);
}

TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
const TfLiteEvalTensor& tensor_in,
TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
TfLiteEvalTensor* tensor_out) {
if (tensor_in.type == kTfLiteFloat32) {
TransposeRowsColumnsImpl<float>(
tensor_in, tflite::micro::GetTensorData<float>(&tensor_in), tensor_out,
tflite::micro::GetTensorData<float>(tensor_out));
TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt8) {
TransposeRowsColumnsImpl<int8_t>(
tensor_in, tflite::micro::GetTensorData<int8_t>(&tensor_in), tensor_out,
tflite::micro::GetTensorData<int8_t>(tensor_out));
TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt16) {
TransposeRowsColumnsImpl<int16_t>(
tensor_in, tflite::micro::GetTensorData<int16_t>(&tensor_in),
tensor_out, tflite::micro::GetTensorData<int16_t>(tensor_out));
TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else {
MicroPrintf(
Expand All @@ -288,7 +278,16 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
return swapped_shape;
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
MicroContext* micro_context = GetMicroContext(context);
return micro_context->AllocatePersistentBuffer(sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

Expand Down Expand Up @@ -322,7 +321,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, op_context));

OpData* op_data = op_context.opdata;
OpData* op_data = op_context.op_data;
// If the RHS is constant, we only transpose once.
op_data->rhs_is_transposed = false;
op_data->lhs_is_constant_tensor = IsConstantTensor(lhs_data);
Expand Down Expand Up @@ -392,19 +391,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
return status;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
MicroContext* micro_context = GetMicroContext(context);
return micro_context->AllocatePersistentBuffer(sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}

TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data,
const RuntimeShape& lhs_shape,
const TfLiteEvalTensor& lhs,
Expand Down Expand Up @@ -478,7 +464,7 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
// A X C row-oriented.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
EvalOpContext op_context(context, node);
OpData* op_data = op_context.opdata;
OpData* op_data = op_context.op_data;
const TfLiteEvalTensor* lhs = op_context.lhs;
const TfLiteEvalTensor* rhs = op_context.rhs;
TfLiteEvalTensor* output = op_context.output;
Expand Down Expand Up @@ -523,12 +509,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/154760341): Constant tensors should already be transposed, but
// we transpose once if necessary for now.
if (!(op_data->rhs_is_constant_tensor && op_data->rhs_is_transposed)) {
TransposeRowsColumns(context, *rhs, rhs_tensor);
TransposeRowsColumns(*rhs, rhs_tensor);
op_data->rhs_is_transposed = true;
}
}
if (adj_x) {
TransposeRowsColumns(context, *lhs, lhs_tensor);
TransposeRowsColumns(*lhs, lhs_tensor);
}
RuntimeShape rhs_shape =
adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
Expand Down
Loading

0 comments on commit de92db3

Please sign in to comment.