From cbd224d6fd9ed4be8db7705af57e4fe8669562dd Mon Sep 17 00:00:00 2001 From: wisse Date: Tue, 3 Dec 2024 13:14:58 +0100 Subject: [PATCH] Refactor acosh into base operator --- ops/acos/acos.go | 4 +- ops/acos/acos_test.go | 8 +-- ops/acos/versions.go | 2 +- ops/acosh/acosh.go | 53 ++++++++++++++ ops/acosh/acosh_9.go | 75 -------------------- ops/acosh/{acosh_9_test.go => acosh_test.go} | 18 ++--- ops/acosh/versions.go | 2 +- 7 files changed, 70 insertions(+), 92 deletions(-) create mode 100644 ops/acosh/acosh.go delete mode 100644 ops/acosh/acosh_9.go rename ops/acosh/{acosh_9_test.go => acosh_test.go} (79%) diff --git a/ops/acos/acos.go b/ops/acos/acos.go index f1b2686..a3de114 100644 --- a/ops/acos/acos.go +++ b/ops/acos/acos.go @@ -14,10 +14,10 @@ type Acos struct { } // newAcos creates a new acos operator. -func newAcos(verion int) ops.Operator { +func newAcos() ops.Operator { return &Acos{ BaseOperator: ops.NewBaseOperator( - verion, + 7, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, diff --git a/ops/acos/acos_test.go b/ops/acos/acos_test.go index 35ac3bf..7eb3ed9 100644 --- a/ops/acos/acos_test.go +++ b/ops/acos/acos_test.go @@ -25,19 +25,19 @@ func TestAcos(t *testing.T) { expected []float32 }{ { - newAcos(7), + newAcos(), []float32{-1, -1, 0, 1}, []int{2, 2}, []float32{3.1415927, 3.1415927, 1.5707964, 0}, }, { - newAcos(7), + newAcos(), []float32{1, 0.5, 0.0, -0.5}, []int{1, 4}, []float32{0, 1.0471976, 1.5707964, 2.0943952}, }, { - newAcos(7), + newAcos(), []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, @@ -87,7 +87,7 @@ func TestInputValidationAcos(t *testing.T) { } for _, test := range tests { - acos := newAcos(7) + acos := newAcos() validated, err := acos.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/acos/versions.go b/ops/acos/versions.go index 814bee2..0892e8f 100644 --- a/ops/acos/versions.go +++ b/ops/acos/versions.go @@ -5,5 +5,5 @@ import ( ) var AcosVersions = ops.OperatorVersions{ - 7: ops.NewOperatorConstructor(newAcos(7)), + 7: newAcos, } diff --git a/ops/acosh/acosh.go b/ops/acosh/acosh.go new file mode 100644 index 0000000..8e91430 --- /dev/null +++ b/ops/acosh/acosh.go @@ -0,0 +1,53 @@ +package acosh + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acosh represents the ONNX acosh operator. +type Acosh struct { + ops.BaseOperator +} + +// newAcosh creates a new acosh operator. +func newAcosh() ops.Operator { + return &Acosh{ + BaseOperator: ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh"), + } +} + +// Init initializes the acosh operator. +func (c *Acosh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the acosh operator. +func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acosh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acosh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func acosh[T ops.FloatType](x T) T { + return T(math.Acosh(float64(x))) +} diff --git a/ops/acosh/acosh_9.go b/ops/acosh/acosh_9.go deleted file mode 100644 index f376795..0000000 --- a/ops/acosh/acosh_9.go +++ /dev/null @@ -1,75 +0,0 @@ -package acosh - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Acosh9 represents the ONNX acosh operator. -type Acosh9 struct{} - -// newAcosh9 creates a new acosh operator. -func newAcosh9() ops.Operator { - return &Acosh9{} -} - -// Init initializes the acosh operator. -func (c *Acosh9) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the acosh operator. -func (c *Acosh9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(acosh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(acosh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Acosh9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Acosh9) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Acosh9) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Acosh9) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Acosh9) String() string { - return "acosh9 operator" -} - -func acosh[T ops.FloatType](x T) T { - return T(math.Acosh(float64(x))) -} diff --git a/ops/acosh/acosh_9_test.go b/ops/acosh/acosh_test.go similarity index 79% rename from ops/acosh/acosh_9_test.go rename to ops/acosh/acosh_test.go index 7172c8b..ae01374 100644 --- a/ops/acosh/acosh_9_test.go +++ b/ops/acosh/acosh_test.go @@ -9,7 +9,7 @@ import ( ) func TestAcosh9Init(t *testing.T) { - c := &Acosh9{} + c := &Acosh{} // since 'acosh' does not have any attributes we pass in nil. This should not // fail initializing the acosh. @@ -19,25 +19,25 @@ func TestAcosh9Init(t *testing.T) { func TestAcosh9(t *testing.T) { tests := []struct { - acosh *Acosh9 + acosh ops.Operator backing []float32 shape []int expected []float32 }{ { - &Acosh9{}, + newAcosh(), []float32{1, 2, 3, 4}, []int{2, 2}, []float32{0, 1.316958, 1.7627472, 2.063437}, }, { - &Acosh9{}, + newAcosh(), []float32{1, 2, 3, 4}, []int{1, 4}, []float32{0, 1.316958, 1.7627472, 2.063437}, }, { - &Acosh9{}, + newAcosh(), []float32{2, 2, 2, 2}, []int{1, 4}, []float32{1.316958, 1.316958, 1.316958, 1.316958}, @@ -57,7 +57,7 @@ func TestAcosh9(t *testing.T) { } } -func TestInputValidationAcosh9(t *testing.T) { +func TestInputValidationAcosh(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -76,18 +76,18 @@ func TestInputValidationAcosh9(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Acosh9{}), + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Acosh9{}), + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), }, } for _, test := range tests { - acosh := &Acosh9{} + acosh := newAcosh() validated, err := acosh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/acosh/versions.go b/ops/acosh/versions.go index 3fff494..0f32b9f 100644 --- a/ops/acosh/versions.go +++ b/ops/acosh/versions.go @@ -5,5 +5,5 @@ import ( ) var AcoshVersions = ops.OperatorVersions{ - 9: newAcosh9, + 9: newAcosh, }