diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go new file mode 100644 index 0000000..9e766c0 --- /dev/null +++ b/ops/opset13/acos.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acos represents the ONNX acos operator. +type Acos struct{} + +// newAcos creates a new acos operator. +func newAcos() ops.Operator { + return &Acos{} +} + +// Init initializes the acos operator. +func (c *Acos) Init(_ []*onnx.AttributeProto) error { + return nil +} + +// Apply applies the acos operator. +func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acos[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acos[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Acos) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Acos) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Acos) String() string { + return "acos operator" +} + +func acos[T ops.FloatType](x T) T { + return T(math.Acos(float64(x))) +} diff --git a/ops/opset13/acos_test.go b/ops/opset13/acos_test.go new file mode 100644 index 0000000..e2c755a --- /dev/null +++ b/ops/opset13/acos_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAcosInit(t *testing.T) { + c := &Acos{} + + // since 'acos' does not have any attributes we pass in nil. This should not + // fail initializing the acos. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestAcos(t *testing.T) { + tests := []struct { + acos *Acos + backing []float32 + shape []int + expected []float32 + }{ + { + &Acos{}, + []float32{-1, -1, 0, 1}, + []int{2, 2}, + []float32{3.1415927, 3.1415927, 1.5707964, 0}, + }, + { + &Acos{}, + []float32{1, 0.5, 0.0, -0.5}, + []int{1, 4}, + []float32{0, 1.0471976, 1.5707964, 2.0943952}, + }, + { + &Acos{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.acos.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAcos(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Acos{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Acos{}), + }, + } + + for _, test := range tests { + acos := &Acos{} + validated, err := acos.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 953ca27..7384a05 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -6,6 +6,7 @@ import ( var operators13 = map[string]func() ops.Operator{ "Abs": newAbs, + "Acos": newAcos, "Acosh": newAcosh, "Add": newAdd, "Cast": newCast, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index cb27818..5a2fc35 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -18,6 +18,11 @@ func TestGetOperator(t *testing.T) { newAbs(), nil, }, + { + "Acos", + newAcos(), + nil, + }, { "Acosh", newAcosh(), diff --git a/ops_test.go b/ops_test.go index f54a2e6..df7ea8e 100644 --- a/ops_test.go +++ b/ops_test.go @@ -272,6 +272,8 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( // With this we check if we truly run all tests we expected from the integration test. var expectedTests = []string{ "test_abs", + "test_acos", + "test_acos_example", "test_acosh", "test_acosh_example", "test_add",