Skip to content

Commit

Permalink
Added atan operator (#165)
Browse files Browse the repository at this point in the history
* Added atan operator

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Nov 27, 2023
1 parent 62247cb commit ddea0c3
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 0 deletions.
75 changes: 75 additions & 0 deletions ops/opset13/atan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Atan represents the ONNX atan operator.
type Atan struct{}

// newAtan creates a new atan operator.
func newAtan() ops.Operator {
return &Atan{}
}

// Init initializes the atan operator.
func (a *Atan) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the atan operator.
func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(atan[float32])
case tensor.Float64:
out, err = inputs[0].Apply(atan[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (a *Atan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(a, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (a *Atan) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (a *Atan) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (a *Atan) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (a *Atan) String() string {
return "atan operator"
}

func atan[T ops.FloatType](x T) T {
return T(math.Atan(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/atan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestAtanInit(t *testing.T) {
a := &Atan{}

// since 'atan' does not have any attributes we pass in nil. This should not
// fail initializing the atan.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestAtan(t *testing.T) {
tests := []struct {
atan *Atan
backing []float32
shape []int
expected []float32
}{
{
&Atan{},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{0.7853982, 1.1071488, 1.2490457, 1.3258177},
},
{
&Atan{},
[]float32{1, 2, 3, 4},
[]int{1, 4},
[]float32{0.7853982, 1.1071488, 1.2490457, 1.3258177},
},
{
&Atan{},
[]float32{2, 2, 2, 2},
[]int{1, 4},
[]float32{1.1071488, 1.1071488, 1.1071488, 1.1071488},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.atan.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationAtan(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Atan{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Atan{}),
},
}

for _, test := range tests {
atan := &Atan{}
validated, err := atan.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var operators13 = map[string]func() ops.Operator{
"Acosh": newAcosh,
"Add": newAdd,
"Asin": newAsin,
"Atan": newAtan,
"Cast": newCast,
"Concat": newConcat,
"Constant": newConstant,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func TestGetOperator(t *testing.T) {
newAdd(),
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Asin",
newAsin(),
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ var expectedTests = []string{
"test_add_bcast",
"test_asin",
"test_asin_example",
"test_atan",
"test_atan_example",
"test_cast_DOUBLE_to_FLOAT",
"test_cast_FLOAT_to_DOUBLE",
"test_concat_1d_axis_0",
Expand Down

0 comments on commit ddea0c3

Please sign in to comment.