From cc701d7b8d14cd04b9f641aa024aa616aa72bdad Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 5 Sep 2023 07:36:00 +0200 Subject: [PATCH 1/3] Added acos operator --- ops/errors.go | 3 ++ ops/opset13/acos.go | 76 ++++++++++++++++++++++++++++ ops/opset13/acos_test.go | 99 +++++++++++++++++++++++++++++++++++++ ops/opset13/opset13.go | 1 + ops/opset13/opset13_test.go | 5 ++ ops_test.go | 2 + 6 files changed, 186 insertions(+) create mode 100644 ops/opset13/acos.go create mode 100644 ops/opset13/acos_test.go diff --git a/ops/errors.go b/ops/errors.go index ab25387..59f51a0 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -38,3 +38,6 @@ const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x < // AxesNotAllInRangeErrTemplate is used to format an error when not all indices // are within a given range. const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d" + +// UnsupportedDTypeError is used when the DType of a tensor is not supported. +const UnsupportedDtypeErrTemplate = "dtype %v is not supported for operator %v" diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go new file mode 100644 index 0000000..5615b2e --- /dev/null +++ b/ops/opset13/acos.go @@ -0,0 +1,76 @@ +package opset13 + +import ( + "fmt" + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acos represents the ONNX acos operator. +type Acos struct{} + +// newAcos creates a new acos operator. +func newAcos() ops.Operator { + return &Acos{} +} + +// Init initializes the acos operator. +func (c *Acos) Init(attributes []*onnx.AttributeProto) error { + return nil +} + +type AcosDType interface { + float32 | float64 +} + +// Apply applies the acos operator. +func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var out tensor.Tensor + var err error + if inputs[0].Dtype() == tensor.Float32 { + out, err = inputs[0].Apply(acos[float32]) + } else if inputs[0].Dtype() == tensor.Float64 { + out, err = inputs[0].Apply(acos[float64]) + } else { + return nil, fmt.Errorf(ops.UnsupportedDtypeErrTemplate, inputs[0].Dtype(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Acos) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Acos) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Acos) String() string { + return "acos operator" +} + +func acos[T AcosDType](x T) T { + return T(math.Acos(float64(x))) +} diff --git a/ops/opset13/acos_test.go b/ops/opset13/acos_test.go new file mode 100644 index 0000000..c90ed3a --- /dev/null +++ b/ops/opset13/acos_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "fmt" + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAcosInit(t *testing.T) { + c := &Acos{} + + // since 'acos' does not have any attributes we pass in nil. This should not + // fail initializing the acos. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestAcos(t *testing.T) { + tests := []struct { + acos *Acos + backing []float32 + shape []int + expected []float32 + }{ + { + &Acos{}, + []float32{-1, -1, 0, 1}, + []int{2, 2}, + []float32{3.1415927, 3.1415927, 1.5707964, 0}, + }, + { + &Acos{}, + []float32{1, 0.5, 0.0, -0.5}, + []int{1, 4}, + []float32{0, 1.0471976, 1.5707964, 2.0943952}, + }, + { + &Acos{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.acos.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAcos(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + fmt.Errorf("acos operator: expected 1 input tensors, got 0"), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + fmt.Errorf("acos operator: input 0 does not allow type int"), + }, + } + + for _, test := range tests { + acos := &Acos{} + validated, err := acos.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 7bbb74c..9bf3160 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -8,6 +8,7 @@ import ( var operators13 = map[string]func() ops.Operator{ "Abs": newAbs, + "Acos": newAcos, "Add": newAdd, "Cast": newCast, "Concat": newConcat, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index c826e46..6153548 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -19,6 +19,11 @@ func TestGetOperator(t *testing.T) { newAbs(), nil, }, + { + "Acos", + newAcos(), + nil, + }, { "Add", newAdd(), diff --git a/ops_test.go b/ops_test.go index 6fb94fd..71fdd80 100644 --- a/ops_test.go +++ b/ops_test.go @@ -255,6 +255,8 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( // With this we check if we truly run all tests we expected from the integration test. var expectedTests = []string{ "test_abs", + "test_acos", + "test_acos_example", "test_add", "test_add_bcast", "test_cast_DOUBLE_to_FLOAT", From e2dabc2b1af70dfdad3bcb076ce56a3dbaae34d7 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 13 Nov 2023 14:44:33 +0100 Subject: [PATCH 2/3] Merge develop --- ops/opset13/acos.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go index 4781b6c..a03778c 100644 --- a/ops/opset13/acos.go +++ b/ops/opset13/acos.go @@ -21,10 +21,6 @@ func (c *Acos) Init(_ []*onnx.AttributeProto) error { return nil } -type AcosDType interface { - float32 | float64 -} - // Apply applies the acos operator. func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var out tensor.Tensor @@ -73,6 +69,6 @@ func (c *Acos) String() string { return "acos operator" } -func acos[T AcosDType](x T) T { +func acos[T ops.FloatType](x T) T { return T(math.Acos(float64(x))) } From 62528375501fb28083a2e4fef4432481e1c6831e Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 13 Nov 2023 14:45:38 +0100 Subject: [PATCH 3/3] Group declarations --- ops/opset13/acos.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go index a03778c..9e766c0 100644 --- a/ops/opset13/acos.go +++ b/ops/opset13/acos.go @@ -23,9 +23,10 @@ func (c *Acos) Init(_ []*onnx.AttributeProto) error { // Apply applies the acos operator. func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var out tensor.Tensor - - var err error + var ( + out tensor.Tensor + err error + ) switch inputs[0].Dtype() { case tensor.Float32: