Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added acos operator #162

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions ops/opset13/acos.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Acos represents the ONNX acos operator.
type Acos struct{}

// newAcos creates a new acos operator.
func newAcos() ops.Operator {
return &Acos{}
}

// Init initializes the acos operator.
func (c *Acos) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the acos operator.
func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(acos[float32])
case tensor.Float64:
out, err = inputs[0].Apply(acos[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Acos) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Acos) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Acos) String() string {
return "acos operator"
}

func acos[T ops.FloatType](x T) T {
return T(math.Acos(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/acos_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestAcosInit(t *testing.T) {
c := &Acos{}

// since 'acos' does not have any attributes we pass in nil. This should not
// fail initializing the acos.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestAcos(t *testing.T) {
tests := []struct {
acos *Acos
backing []float32
shape []int
expected []float32
}{
{
&Acos{},
[]float32{-1, -1, 0, 1},
[]int{2, 2},
[]float32{3.1415927, 3.1415927, 1.5707964, 0},
},
{
&Acos{},
[]float32{1, 0.5, 0.0, -0.5},
[]int{1, 4},
[]float32{0, 1.0471976, 1.5707964, 2.0943952},
},
{
&Acos{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{3.1415927, 3.1415927, 3.1415927, 3.1415927},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.acos.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationAcos(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Acos{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Acos{}),
},
}

for _, test := range tests {
acos := &Acos{}
validated, err := acos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

var operators13 = map[string]func() ops.Operator{
"Abs": newAbs,
"Acos": newAcos,
"Add": newAdd,
"Cast": newCast,
"Concat": newConcat,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func TestGetOperator(t *testing.T) {
newAbs(),
nil,
},
{
"Acos",
newAcos(),
nil,
},
{
"Add",
newAdd(),
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) (
// With this we check if we truly run all tests we expected from the integration test.
var expectedTests = []string{
"test_abs",
"test_acos",
"test_acos_example",
"test_add",
"test_add_bcast",
"test_cast_DOUBLE_to_FLOAT",
Expand Down