Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added atan operator #165

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions ops/opset13/atan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Atan represents the ONNX atan operator.
type Atan struct{}

// newAtan creates a new atan operator.
func newAtan() ops.Operator {
return &Atan{}
}

// Init initializes the atan operator.
func (a *Atan) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the atan operator.
func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(atan[float32])
case tensor.Float64:
out, err = inputs[0].Apply(atan[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (a *Atan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(a, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (a *Atan) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (a *Atan) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (a *Atan) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (a *Atan) String() string {
return "atan operator"
}

func atan[T ops.FloatType](x T) T {
return T(math.Atan(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/atan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestAtanInit(t *testing.T) {
a := &Atan{}

// since 'atan' does not have any attributes we pass in nil. This should not
// fail initializing the atan.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestAtan(t *testing.T) {
tests := []struct {
atan *Atan
backing []float32
shape []int
expected []float32
}{
{
&Atan{},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{0.7853982, 1.1071488, 1.2490457, 1.3258177},
},
{
&Atan{},
[]float32{1, 2, 3, 4},
[]int{1, 4},
[]float32{0.7853982, 1.1071488, 1.2490457, 1.3258177},
},
{
&Atan{},
[]float32{2, 2, 2, 2},
[]int{1, 4},
[]float32{1.1071488, 1.1071488, 1.1071488, 1.1071488},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.atan.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationAtan(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Atan{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Atan{}),
},
}

for _, test := range tests {
atan := &Atan{}
validated, err := atan.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ var operators13 = map[string]func() ops.Operator{
"Acos": newAcos,
"Acosh": newAcosh,
"Add": newAdd,
"Atan": newAtan,
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved
"Cast": newCast,
"Concat": newConcat,
"Constant": newConstant,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func TestGetOperator(t *testing.T) {
newAdd(),
nil,
},
{
"Atan",
newAtan(),
nil,
},
{
"Cast",
newCast(),
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ var expectedTests = []string{
"test_acosh_example",
"test_add",
"test_add_bcast",
"test_atan",
"test_atan_example",
"test_cast_DOUBLE_to_FLOAT",
"test_cast_FLOAT_to_DOUBLE",
"test_concat_1d_axis_0",
Expand Down