Skip to content

Commit

Permalink
Added ArgMax operator (#209)
Browse files Browse the repository at this point in the history
* Added ArgMax operator

* Use old linter

* Fix lint

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Oct 21, 2024
1 parent 6e4a050 commit 7ed9d6f
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 21 deletions.
127 changes: 127 additions & 0 deletions ops/opset13/argmax.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package opset13

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

const (
MinArgMaxInputs = 1
MaxArgMaxInputs = 1
)

// ArgMax represents the ONNX argmax operator.
type ArgMax struct {
axis int
keepDims bool
selectLastIndex bool
}

// newArgMax creates a new argmax operator.
func newArgMax() ops.Operator {
return &ArgMax{
keepDims: true,
selectLastIndex: false,
}
}

type ArgMaxAttribute string

const (
axis = "axis"
keepDims = "keepdims"
selectLastIndex = "select_last_index"
)

// Init initializes the argmax operator.
func (a *ArgMax) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
for _, attr := range attributes {
switch attr.GetName() {
case axis:
a.axis = int(attr.GetI())
case keepDims:
a.keepDims = ops.Int64ToBool(attr.GetI())
case selectLastIndex:
a.selectLastIndex = ops.Int64ToBool(attr.GetI())

// We have no way yet to perform argmax and keeping the
// last index as max in case of duplicates, so if this
// attribute is true, we raise an unsupported error.
if a.selectLastIndex {
return ops.ErrUnsupportedAttribute(attr.GetName(), a)
}
default:
return ops.ErrInvalidAttribute(attr.GetName(), a)
}
}

return nil
}

// Apply applies the argmax operator.
func (a *ArgMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
axis := ops.ConvertNegativeAxis(a.axis, len(inputs[0].Shape()))

reduced, err := tensor.Argmax(inputs[0], axis)
if err != nil {
return nil, err
}

// Keep the reduced dimension, i.e. if the reduced axis was '1', and
// the original shape was (2, 4, 5), the reduced shape would be (2, 5).
// If keepDims is true, that shape should be (2, 1, 5).
if a.keepDims {
newShape := inputs[0].Shape()
newShape[axis] = 1

if err := reduced.Reshape(newShape...); err != nil {
return nil, err
}
}

// The tensor.Argmax function returns data of type int, but according to
// the ONNX standard this operator should return int64.
backing, ok := reduced.Data().([]int)
if !ok {
return nil, ops.ErrTypeAssert("int", reduced.Dtype())
}

backing2 := make([]int64, len(backing))
for i := range backing {
backing2[i] = int64(backing[i])
}

reduced = tensor.New(tensor.WithShape(reduced.Shape()...), tensor.WithBacking(backing2))

return []tensor.Tensor{reduced}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (a *ArgMax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(a, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (a *ArgMax) GetMinInputs() int {
return MinArgMaxInputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (a *ArgMax) GetMaxInputs() int {
return MaxArgMaxInputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (a *ArgMax) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{
{tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (a *ArgMax) String() string {
return "argmax operator"
}
134 changes: 134 additions & 0 deletions ops/opset13/argmax_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestArgMaxInit(t *testing.T) {
a := &ArgMax{}

err := a.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "axis", I: 2},
{Name: "keepdims", I: 0},
{Name: "select_last_index", I: 0},
},
},
)
assert.Nil(t, err)

assert.Equal(t, 2, a.axis)
assert.Equal(t, false, a.keepDims)
assert.Equal(t, false, a.selectLastIndex)
}

func TestArgMax(t *testing.T) {
tests := []struct {
argmax *ArgMax
backing []float32
shape []int
expectedShape tensor.Shape
expectedData []int64
}{
{
&ArgMax{axis: 0, keepDims: true},
[]float32{0, 1, 2, 3},
[]int{2, 2},
[]int{1, 2},
[]int64{1, 1},
},
{
&ArgMax{axis: -1, keepDims: true},
[]float32{0, 1, 2, 3},
[]int{2, 2},
[]int{2, 1},
[]int64{1, 1},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.argmax.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expectedShape, res[0].Shape())
assert.Equal(t, test.expectedData, res[0].Data())
}
}

func TestInputValidationArgMax(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
ops.ErrInvalidInputCount(2, &ArgMax{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &ArgMax{}),
},
}

for _, test := range tests {
argmax := &ArgMax{}
validated, err := argmax.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
2 changes: 1 addition & 1 deletion ops/opset13/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func newExpand() ops.Operator {
}

// Init initializes the expand operator.
func (f *Expand) Init(n *onnx.NodeProto) error {
func (f *Expand) Init(*onnx.NodeProto) error {
return nil
}

Expand Down
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var operators13 = map[string]func() ops.Operator{
"Acosh": newAcosh,
"Add": newAdd,
"And": newAnd,
"ArgMax": newArgMax,
"Asin": newAsin,
"Asinh": newAsinh,
"Atan": newAtan,
Expand Down
20 changes: 13 additions & 7 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,11 @@ func TestGetOperator(t *testing.T) {
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Atanh",
newAtanh(),
"ArgMax",
newArgMax(),
nil,
},

{
"Asin",
newAsin(),
Expand All @@ -58,6 +54,16 @@ func TestGetOperator(t *testing.T) {
newAsinh(),
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Atanh",
newAtanh(),
nil,
},
{
"Cast",
newCast(),
Expand Down
7 changes: 1 addition & 6 deletions ops/opset13/reduce_max.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ func (r *ReduceMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

axes := make([]int, len(r.axes))
for i, axis := range r.axes {
// Convert negative dimensions.
if axis < 0 {
axis = len(input.Shape()) + axis
}

axes[i] = axis
axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape()))
}

out, err := input.Max(axes...)
Expand Down
8 changes: 1 addition & 7 deletions ops/opset13/reduce_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ func (r *ReduceMin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

axes := make([]int, len(r.axes))
for i, axis := range r.axes {
// Convert negative dimensions to positive dimensions as Go does not support
// negative dimension indexing like Python does.
if axis < 0 {
axis = len(input.Shape()) + axis
}

axes[i] = axis
axes[i] = ops.ConvertNegativeAxis(axis, len(input.Shape()))
}

out, err := input.Min(axes...)
Expand Down
17 changes: 17 additions & 0 deletions ops/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,20 @@ func PairwiseAssign(t1, t2 tensor.Tensor) (err error) {

return nil
}

// Converts a negative axis to the corresponding axis such that it can be used as index.
// For example, if axis is -1, this represents the last dimension. Go does not support
// negative indexing (as opposed to Python, on which ONNX is heavily dependent), so we
// have to convert the negative axis to the positive axis it represents, which is dependent
// on the rank (number of dimensions) of the tensor.
// Example 1: if rank is 3, and axis is -1, the corresponding positive axis is 2.
// Example 2: if rank is 4, and axis is -1, the corresponding positive axis is 3.
// Example 3: if rank is 4, and axis is -3, the corresponding positive axis is 1.
// Example 4: if rank is 3, and axis is 2, the function does nothing.
func ConvertNegativeAxis(axis, rank int) int {
if axis < 0 {
axis = rank + axis
}

return axis
}
17 changes: 17 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ var ignoredTests = []string{
"test_prelu_broadcast_expanded", // Unsupported operator CastLike
"test_prelu_example_expanded", // Unsupported operator CastLike
"test_constant_pad_negative_axes", // Unsupported operator Pad

"test_argmax_keepdims_random_select_last_index", // Unsupported attribute
"test_argmax_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_no_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_no_keepdims_random_select_last_index", // Unsupported attribute
"test_argmax_default_axis_example_select_last_index", // Unsupported attribute
"test_argmax_default_axis_random_select_last_index", // Unsupported attribute
"test_argmax_negative_axis_keepdims_example_select_last_index", // Unsupported attribute
"test_argmax_negative_axis_keepdims_random_select_last_index", // Unsupported attribute
}

type ONNXTestCase struct {
Expand Down Expand Up @@ -354,6 +363,14 @@ var expectedTests = []string{
"test_and_bcast4v2d",
"test_and_bcast4v3d",
"test_and_bcast4v4d",
"test_argmax_default_axis_example",
"test_argmax_default_axis_random",
"test_argmax_keepdims_example",
"test_argmax_keepdims_random",
"test_argmax_negative_axis_keepdims_example",
"test_argmax_negative_axis_keepdims_random",
"test_argmax_no_keepdims_example",
"test_argmax_no_keepdims_random",
"test_asin",
"test_asin_example",
"test_asinh",
Expand Down

0 comments on commit 7ed9d6f

Please sign in to comment.