diff --git a/ops/abs/abs.go b/ops/abs/abs.go new file mode 100644 index 0000000..d0c05cf --- /dev/null +++ b/ops/abs/abs.go @@ -0,0 +1,44 @@ +package abs + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var absTypeConstraint = [][]tensor.Dtype{ + {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Abs represents the ONNX abs operator. +type Abs struct { + ops.BaseOperator +} + +// newAbs creates a new abs operator. +func newAbs(version int, typeConstraint [][]tensor.Dtype) *Abs { + return &Abs{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraint, + "abs", + ), + } +} + +// Init initializes the abs operator. +func (a *Abs) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the abs operator. +func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := tensor.Abs(inputs[0]) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/abs/abs_13.go b/ops/abs/abs_13.go deleted file mode 100644 index 0ad2ad5..0000000 --- a/ops/abs/abs_13.go +++ /dev/null @@ -1,63 +0,0 @@ -package abs - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinAbs13Inputs = 1 - MaxAbs13Inputs = 1 -) - -// Abs13 represents the ONNX abs operator. -type Abs13 struct{} - -// newAbs13 creates a new abs operator. -func newAbs13() ops.Operator { - return &Abs13{} -} - -// Init initializes the abs operator. -func (a *Abs13) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the abs operator. -func (a *Abs13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := tensor.Abs(inputs[0]) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Abs13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Abs13) GetMinInputs() int { - return MinAbs13Inputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Abs13) GetMaxInputs() int { - return MaxAbs13Inputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Abs13) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Abs13) String() string { - return "abs13 operator" -} diff --git a/ops/abs/abs_6.go b/ops/abs/abs_6.go deleted file mode 100644 index e4c4f6e..0000000 --- a/ops/abs/abs_6.go +++ /dev/null @@ -1,63 +0,0 @@ -package abs - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinAbs6Inputs = 1 - MaxAbs6Inputs = 1 -) - -// Abs6 represents the ONNX abs operator. -type Abs6 struct{} - -// newAbs6 creates a new abs operator. -func newAbs6() ops.Operator { - return &Abs6{} -} - -// Init initializes the abs operator. -func (a *Abs6) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the abs operator. -func (a *Abs6) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := tensor.Abs(inputs[0]) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Abs6) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Abs6) GetMinInputs() int { - return MinAbs6Inputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Abs6) GetMaxInputs() int { - return MaxAbs6Inputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Abs6) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Abs6) String() string { - return "abs6 operator" -} diff --git a/ops/abs/abs_13_test.go b/ops/abs/abs_test.go similarity index 53% rename from ops/abs/abs_13_test.go rename to ops/abs/abs_test.go index 1930879..845fc39 100644 --- a/ops/abs/abs_13_test.go +++ b/ops/abs/abs_test.go @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestAbs13Init(t *testing.T) { - a := &Abs13{} +func TestAbsInit(t *testing.T) { + a := &Abs{} // since 'abs' does not have any attributes we pass in nil. This should not // fail initializing the abs. @@ -17,27 +17,27 @@ func TestAbs13Init(t *testing.T) { assert.Nil(t, err) } -func TestAbs13(t *testing.T) { +func TestAbs(t *testing.T) { tests := []struct { - abs *Abs13 + abs *Abs backing []float32 shape []int expected []float32 }{ { - &Abs13{}, + &Abs{}, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{2, 1, 0, 1}, }, { - &Abs13{}, + &Abs{}, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{1, 3, 4, 5}, }, { - &Abs13{}, + &Abs{}, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{1, 1, 1, 1}, @@ -57,85 +57,180 @@ func TestAbs13(t *testing.T) { } } -func TestInputValidationAbs13(t *testing.T) { +func TestInputValidationAbs(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + inputs []tensor.Tensor + err error + version int64 }{ { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint8{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint16{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int8{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int16{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Abs13{}), + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(6, 1, 1, absTypeConstraint, "abs")), + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Abs13{}), + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(6, 1, 1, absTypeConstraint, "abs")), + 6, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint16{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int16{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")), + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")), + 13, }, } for _, test := range tests { - abs := &Abs13{} + abs := AbsVersions[test.version]() validated, err := abs.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/abs/versions.go b/ops/abs/versions.go index 975523e..11c2325 100644 --- a/ops/abs/versions.go +++ b/ops/abs/versions.go @@ -5,6 +5,12 @@ import ( ) var AbsVersions = ops.OperatorVersions{ - 6: newAbs6, // Same, but bfloat16 type is added - 13: newAbs13, + 6: newConstructor(newAbs(6, absTypeConstraint)), // Same, but bfloat16 type is added + 13: newConstructor(newAbs(13, absTypeConstraint)), // Same, but bfloat16 type is added +} + +func newConstructor(base *Abs) func() ops.Operator { + return func() ops.Operator { + return base + } }