Skip to content

Commit

Permalink
Refactor type of dimension attributes in TOSA reference implementation (
Browse files Browse the repository at this point in the history
#418)

Closes #390.
  • Loading branch information
henri-gruender authored Apr 4, 2024
1 parent ff3be56 commit f896d3f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
16 changes: 8 additions & 8 deletions reference-implementation/include/emitc/tosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ namespace {
// Common reduce function used by specialized TOSA reduce ops.
template <typename Dest, typename Src, typename Computation>
inline Dest reduce(Src operand, typename get_element_type<Src>::type initValue,
int64_t dimension, Computation computation) {
int32_t dimension, Computation computation) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");

Expand Down Expand Up @@ -688,7 +688,7 @@ inline Dest reduce(Src operand, typename get_element_type<Src>::type initValue,

// ArgMaxOp
template <typename Dest, typename Src>
inline Dest argmax(Src operand, int64_t dimension) {
inline Dest argmax(Src operand, int32_t dimension) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");

Expand Down Expand Up @@ -732,7 +732,7 @@ inline Dest argmax(Src operand, int64_t dimension) {

// ReduceAllOp
template <typename Dest, typename Src>
inline Dest reduce_all(Src input, int64_t dimension) {
inline Dest reduce_all(Src input, int32_t dimension) {
// ReduceAllOp takes only tensors with datatype bool according to the
// TOSA specifications.
using ET_Src = typename get_element_type<Src>::type;
Expand All @@ -750,7 +750,7 @@ inline Dest reduce_all(Src input, int64_t dimension) {

// ReduceAnyOp
template <typename Dest, typename Src>
inline Dest reduce_any(Src input, int64_t dimension) {
inline Dest reduce_any(Src input, int32_t dimension) {
// ReduceAnyOp takes only tensors with datatype bool according to the
// TOSA specifications.
using ET_Src = typename get_element_type<Src>::type;
Expand All @@ -768,7 +768,7 @@ inline Dest reduce_any(Src input, int64_t dimension) {

// ReduceMaxOp
template <typename Dest, typename Src>
inline Dest reduce_max(Src input, int64_t dimension) {
inline Dest reduce_max(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

auto f =
Expand All @@ -780,7 +780,7 @@ inline Dest reduce_max(Src input, int64_t dimension) {

// ReduceMinOp
template <typename Dest, typename Src>
inline Dest reduce_min(Src input, int64_t dimension) {
inline Dest reduce_min(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

auto f =
Expand All @@ -792,7 +792,7 @@ inline Dest reduce_min(Src input, int64_t dimension) {

// ReduceProdOp
template <typename Dest, typename Src>
inline Dest reduce_prod(Src input, int64_t dimension) {
inline Dest reduce_prod(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

return tosa::reduce<Dest, Src>(input, 1, dimension,
Expand All @@ -801,7 +801,7 @@ inline Dest reduce_prod(Src input, int64_t dimension) {

// ReduceSumOp
template <typename Dest, typename Src>
inline Dest reduce_sum(Src input, int64_t dimension) {
inline Dest reduce_sum(Src input, int32_t dimension) {
using ET_Src = typename get_element_type<Src>::type;

return tosa::reduce<Dest, Src>(input, 0, dimension, std::plus<ET_Src>{});
Expand Down
10 changes: 5 additions & 5 deletions reference-implementation/unittests/tosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ TEST(tosa, reduce_prod) {
TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 2, 3> input{1, 2, 3, 4, 5, 6};
int64_t dimension = 0;
int32_t dimension = 0;
Tensor<int32_t, 3> expected_result{5, 7, 9};
Tensor<int32_t, 3> result =
tosa::reduce_sum<Tensor<int32_t, 3>>(input, dimension);
Expand All @@ -1145,7 +1145,7 @@ TEST(tosa, reduce_sum) {
}
{
Tensor<int32_t, 2, 3> input{1, 2, 3, 4, 5, 6};
int64_t dimension = 1;
int32_t dimension = 1;
Tensor<int32_t, 2> expected_result{6, 15};
Tensor<int32_t, 2> result =
tosa::reduce_sum<Tensor<int32_t, 2>>(input, dimension);
Expand All @@ -1155,7 +1155,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 0;
int32_t dimension = 0;
Tensor<int32_t, 2, 3> expected_result{4, 8, 12, 16, 20, 24};
Tensor<int32_t, 2, 3> result =
tosa::reduce_sum<Tensor<int32_t, 2, 3>>(input, dimension);
Expand All @@ -1165,7 +1165,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 1;
int32_t dimension = 1;
Tensor<int32_t, 4, 3> expected_result{5, 7, 9, 5, 7, 9, 5, 7, 9, 5, 7, 9};
Tensor<int32_t, 4, 3> result =
tosa::reduce_sum<Tensor<int32_t, 4, 3>>(input, dimension);
Expand All @@ -1175,7 +1175,7 @@ TEST(tosa, reduce_sum) {
{
Tensor<int32_t, 4, 2, 3> input{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
int64_t dimension = 2;
int32_t dimension = 2;
Tensor<int32_t, 4, 2> expected_result{6, 15, 6, 15, 6, 15, 6, 15};
Tensor<int32_t, 4, 2> result =
tosa::reduce_sum<Tensor<int32_t, 4, 2>>(input, dimension);
Expand Down

0 comments on commit f896d3f

Please sign in to comment.