-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added ArgMax operator * Use old linter * Fix lint --------- Co-authored-by: Swopper050 <[email protected]>
- Loading branch information
1 parent
6e4a050
commit 7ed9d6f
Showing
9 changed files
with
312 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
package opset13 | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
const ( | ||
MinArgMaxInputs = 1 | ||
MaxArgMaxInputs = 1 | ||
) | ||
|
||
// ArgMax represents the ONNX argmax operator. | ||
type ArgMax struct { | ||
axis int | ||
keepDims bool | ||
selectLastIndex bool | ||
} | ||
|
||
// newArgMax creates a new argmax operator. | ||
func newArgMax() ops.Operator { | ||
return &ArgMax{ | ||
keepDims: true, | ||
selectLastIndex: false, | ||
} | ||
} | ||
|
||
type ArgMaxAttribute string | ||
|
||
const ( | ||
axis = "axis" | ||
keepDims = "keepdims" | ||
selectLastIndex = "select_last_index" | ||
) | ||
|
||
// Init initializes the argmax operator. | ||
func (a *ArgMax) Init(n *onnx.NodeProto) error { | ||
attributes := n.GetAttribute() | ||
for _, attr := range attributes { | ||
switch attr.GetName() { | ||
case axis: | ||
a.axis = int(attr.GetI()) | ||
case keepDims: | ||
a.keepDims = ops.Int64ToBool(attr.GetI()) | ||
case selectLastIndex: | ||
a.selectLastIndex = ops.Int64ToBool(attr.GetI()) | ||
|
||
// We have no way yet to perform argmax and keeping the | ||
// last index as max in case of duplicates, so if this | ||
// attribute is true, we raise an unsupported error. | ||
if a.selectLastIndex { | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), a) | ||
} | ||
default: | ||
return ops.ErrInvalidAttribute(attr.GetName(), a) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the argmax operator. | ||
func (a *ArgMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
axis := ops.ConvertNegativeAxis(a.axis, len(inputs[0].Shape())) | ||
|
||
reduced, err := tensor.Argmax(inputs[0], axis) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Keep the reduced dimension, i.e. if the reduced axis was '1', and | ||
// the original shape was (2, 4, 5), the reduced shape would be (2, 5). | ||
// If keepDims is true, that shape should be (2, 1, 5). | ||
if a.keepDims { | ||
newShape := inputs[0].Shape() | ||
newShape[axis] = 1 | ||
|
||
if err := reduced.Reshape(newShape...); err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
// The tensor.Argmax function returns data of type int, but according to | ||
// the ONNX standard this operator should return int64. | ||
backing, ok := reduced.Data().([]int) | ||
if !ok { | ||
return nil, ops.ErrTypeAssert("int", reduced.Dtype()) | ||
} | ||
|
||
backing2 := make([]int64, len(backing)) | ||
for i := range backing { | ||
backing2[i] = int64(backing[i]) | ||
} | ||
|
||
reduced = tensor.New(tensor.WithShape(reduced.Shape()...), tensor.WithBacking(backing2)) | ||
|
||
return []tensor.Tensor{reduced}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (a *ArgMax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(a, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (a *ArgMax) GetMinInputs() int { | ||
return MinArgMaxInputs | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (a *ArgMax) GetMaxInputs() int { | ||
return MaxArgMaxInputs | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (a *ArgMax) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{ | ||
{tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, | ||
} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (a *ArgMax) String() string { | ||
return "argmax operator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
package opset13 | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"github.com/stretchr/testify/assert" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
func TestArgMaxInit(t *testing.T) { | ||
a := &ArgMax{} | ||
|
||
err := a.Init( | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "axis", I: 2}, | ||
{Name: "keepdims", I: 0}, | ||
{Name: "select_last_index", I: 0}, | ||
}, | ||
}, | ||
) | ||
assert.Nil(t, err) | ||
|
||
assert.Equal(t, 2, a.axis) | ||
assert.Equal(t, false, a.keepDims) | ||
assert.Equal(t, false, a.selectLastIndex) | ||
} | ||
|
||
func TestArgMax(t *testing.T) { | ||
tests := []struct { | ||
argmax *ArgMax | ||
backing []float32 | ||
shape []int | ||
expectedShape tensor.Shape | ||
expectedData []int64 | ||
}{ | ||
{ | ||
&ArgMax{axis: 0, keepDims: true}, | ||
[]float32{0, 1, 2, 3}, | ||
[]int{2, 2}, | ||
[]int{1, 2}, | ||
[]int64{1, 1}, | ||
}, | ||
{ | ||
&ArgMax{axis: -1, keepDims: true}, | ||
[]float32{0, 1, 2, 3}, | ||
[]int{2, 2}, | ||
[]int{2, 1}, | ||
[]int64{1, 1}, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
inputs := []tensor.Tensor{ | ||
ops.TensorWithBackingFixture(test.backing, test.shape...), | ||
} | ||
|
||
res, err := test.argmax.Apply(inputs) | ||
assert.Nil(t, err) | ||
|
||
assert.Equal(t, test.expectedShape, res[0].Shape()) | ||
assert.Equal(t, test.expectedData, res[0].Data()) | ||
} | ||
} | ||
|
||
func TestInputValidationArgMax(t *testing.T) { | ||
tests := []struct { | ||
inputs []tensor.Tensor | ||
err error | ||
}{ | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]uint32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]uint64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
}, | ||
ops.ErrInvalidInputCount(2, &ArgMax{}), | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int{1, 2}, 2), | ||
}, | ||
ops.ErrInvalidInputType(0, "int", &ArgMax{}), | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
argmax := &ArgMax{} | ||
validated, err := argmax.ValidateInputs(test.inputs) | ||
|
||
assert.Equal(t, test.err, err) | ||
|
||
if test.err == nil { | ||
assert.Equal(t, test.inputs, validated) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters