From c867c16c93c548010df3e8babcbba8d1d863ba6f Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Thu, 7 Sep 2023 07:13:00 +0200 Subject: [PATCH 1/2] Added atan operator --- ops/errors.go | 3 ++ ops/opset13/atan.go | 76 ++++++++++++++++++++++++++++ ops/opset13/atan_test.go | 99 +++++++++++++++++++++++++++++++++++++ ops/opset13/opset13.go | 1 + ops/opset13/opset13_test.go | 5 ++ ops_test.go | 2 + 6 files changed, 186 insertions(+) create mode 100644 ops/opset13/atan.go create mode 100644 ops/opset13/atan_test.go diff --git a/ops/errors.go b/ops/errors.go index ab25387..59f51a0 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -38,3 +38,6 @@ const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x < // AxesNotAllInRangeErrTemplate is used to format an error when not all indices // are within a given range. const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d" + +// UnsupportedDTypeError is used when the DType of a tensor is not supported. +const UnsupportedDtypeErrTemplate = "dtype %v is not supported for operator %v" diff --git a/ops/opset13/atan.go b/ops/opset13/atan.go new file mode 100644 index 0000000..3104421 --- /dev/null +++ b/ops/opset13/atan.go @@ -0,0 +1,76 @@ +package opset13 + +import ( + "fmt" + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Atan represents the ONNX atan operator. +type Atan struct{} + +// newAtan creates a new atan operator. +func newAtan() ops.Operator { + return &Atan{} +} + +// Init initializes the atan operator. +func (a *Atan) Init(attributes []*onnx.AttributeProto) error { + return nil +} + +type AtanDType interface { + float32 | float64 +} + +// Apply applies the atan operator. +func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var out tensor.Tensor + var err error + if inputs[0].Dtype() == tensor.Float32 { + out, err = inputs[0].Apply(atan[float32]) + } else if inputs[0].Dtype() == tensor.Float64 { + out, err = inputs[0].Apply(atan[float64]) + } else { + return nil, fmt.Errorf(ops.UnsupportedDtypeErrTemplate, inputs[0].Dtype(), a) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *Atan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *Atan) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *Atan) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *Atan) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *Atan) String() string { + return "atan operator" +} + +func atan[T AtanDType](x T) T { + return T(math.Atan(float64(x))) +} diff --git a/ops/opset13/atan_test.go b/ops/opset13/atan_test.go new file mode 100644 index 0000000..6c9d07f --- /dev/null +++ b/ops/opset13/atan_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "fmt" + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAtanInit(t *testing.T) { + a := &Atan{} + + // since 'atan' does not have any attributes we pass in nil. This should not + // fail initializing the atan. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestAtan(t *testing.T) { + tests := []struct { + atan *Atan + backing []float32 + shape []int + expected []float32 + }{ + { + &Atan{}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, + }, + { + &Atan{}, + []float32{1, 2, 3, 4}, + []int{1, 4}, + []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, + }, + { + &Atan{}, + []float32{2, 2, 2, 2}, + []int{1, 4}, + []float32{1.1071488, 1.1071488, 1.1071488, 1.1071488}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.atan.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAtan(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + fmt.Errorf("atan operator: expected 1 input tensors, got 0"), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + fmt.Errorf("atan operator: input 0 does not allow type int"), + }, + } + + for _, test := range tests { + atan := &Atan{} + validated, err := atan.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 7bbb74c..65f5e3b 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -9,6 +9,7 @@ import ( var operators13 = map[string]func() ops.Operator{ "Abs": newAbs, "Add": newAdd, + "Atan": newAtan, "Cast": newCast, "Concat": newConcat, "Constant": newConstant, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index c826e46..f1db767 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -24,6 +24,11 @@ func TestGetOperator(t *testing.T) { newAdd(), nil, }, + { + "Atan", + newAtan(), + nil, + }, { "Cast", newCast(), diff --git a/ops_test.go b/ops_test.go index 6fb94fd..6e97251 100644 --- a/ops_test.go +++ b/ops_test.go @@ -257,6 +257,8 @@ var expectedTests = []string{ "test_abs", "test_add", "test_add_bcast", + "test_atan", + "test_atan_example", "test_cast_DOUBLE_to_FLOAT", "test_cast_FLOAT_to_DOUBLE", "test_concat_1d_axis_0", From eee56fc4e0d78aaeb5c26ad61afebcce4433eecc Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 25 Nov 2023 19:53:43 +0100 Subject: [PATCH 2/2] Use FloatType --- ops/opset13/atan.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ops/opset13/atan.go b/ops/opset13/atan.go index 17bf1c1..2abdaaa 100644 --- a/ops/opset13/atan.go +++ b/ops/opset13/atan.go @@ -21,10 +21,6 @@ func (a *Atan) Init(_ []*onnx.AttributeProto) error { return nil } -type AtanDType interface { - float32 | float64 -} - // Apply applies the atan operator. func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var ( @@ -74,6 +70,6 @@ func (a *Atan) String() string { return "atan operator" } -func atan[T AtanDType](x T) T { +func atan[T ops.FloatType](x T) T { return T(math.Atan(float64(x))) }