From 555c3acb26bbee7b626ee72946036ee91222543f Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Wed, 4 Dec 2024 08:45:30 +0100 Subject: [PATCH] Refactored Tan operator --- ops/tan/tan.go | 61 +++++++++++++++++++++ ops/tan/tan_7.go | 75 -------------------------- ops/tan/{tan_7_test.go => tan_test.go} | 35 +++++++----- ops/tan/versions.go | 2 +- 4 files changed, 84 insertions(+), 89 deletions(-) create mode 100644 ops/tan/tan.go delete mode 100644 ops/tan/tan_7.go rename ops/tan/{tan_7_test.go => tan_test.go} (74%) diff --git a/ops/tan/tan.go b/ops/tan/tan.go new file mode 100644 index 0000000..996f15e --- /dev/null +++ b/ops/tan/tan.go @@ -0,0 +1,61 @@ +package tan + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var tanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Tan represents the ONNX tan operator. +type Tan struct { + ops.BaseOperator +} + +// newTan creates a new tan operator. +func newTan(version int, typeConstraints [][]tensor.Dtype) *Tan { + return &Tan{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "tan", + ), + } +} + +// Init initializes the tan operator. +func (t *Tan) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the tan operator. +func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(tan[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(tan[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func tan[T ops.FloatType](x T) T { + return T(math.Tan(float64(x))) +} diff --git a/ops/tan/tan_7.go b/ops/tan/tan_7.go deleted file mode 100644 index 42afee5..0000000 --- a/ops/tan/tan_7.go +++ /dev/null @@ -1,75 +0,0 @@ -package tan - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Tan7 represents the ONNX tan operator. -type Tan7 struct{} - -// newTan7 creates a new tan operator. -func newTan7() ops.Operator { - return &Tan7{} -} - -// Init initializes the tan operator. -func (t *Tan7) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the tan operator. -func (t *Tan7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(tan[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(tan[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (t *Tan7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(t, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (t *Tan7) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (t *Tan7) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (t *Tan7) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (t *Tan7) String() string { - return "tan7 operator" -} - -func tan[T ops.FloatType](x T) T { - return T(math.Tan(float64(x))) -} diff --git a/ops/tan/tan_7_test.go b/ops/tan/tan_test.go similarity index 74% rename from ops/tan/tan_7_test.go rename to ops/tan/tan_test.go index dcdd5d6..8b888b4 100644 --- a/ops/tan/tan_7_test.go +++ b/ops/tan/tan_test.go @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestTan7Init(t *testing.T) { - a := &Tan7{} +func TestTanInit(t *testing.T) { + a := &Tan{} // since 'tan' does not have any attributes we pass in nil. This should not // fail initializing the tan. @@ -17,27 +17,27 @@ func TestTan7Init(t *testing.T) { assert.Nil(t, err) } -func TestTan7(t *testing.T) { +func TestTan(t *testing.T) { tests := []struct { - tan *Tan7 + tan *Tan backing []float32 shape []int expected []float32 }{ { - &Tan7{}, + &Tan{}, []float32{1, 2, 3, 4}, []int{2, 2}, []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, }, { - &Tan7{}, + &Tan{}, []float32{1, 2, 3, 4}, []int{1, 4}, []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, }, { - &Tan7{}, + &Tan{}, []float32{2, 2, 2, 2}, []int{1, 4}, []float32{-2.1850398, -2.1850398, -2.1850398, -2.1850398}, @@ -57,37 +57,42 @@ func TestTan7(t *testing.T) { } } -func TestInputValidationTan7(t *testing.T) { +func TestInputValidationTan(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Tan7{}), + ops.ErrInvalidInputCount(0, tan7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Tan7{}), + ops.ErrInvalidInputType(0, "int", tan7BaseOpFixture()), }, } for _, test := range tests { - tan := &Tan7{} + tan := TanVersions[test.version]() validated, err := tan.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +102,7 @@ func TestInputValidationTan7(t *testing.T) { } } } + +func tan7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, tanTypeConstraints, "tan") +} diff --git a/ops/tan/versions.go b/ops/tan/versions.go index 014a0b2..7e5953b 100644 --- a/ops/tan/versions.go +++ b/ops/tan/versions.go @@ -3,5 +3,5 @@ package tan import "github.com/advancedclimatesystems/gonnx/ops" var TanVersions = ops.OperatorVersions{ - 7: newTan7, + 7: ops.NewOperatorConstructor(newTan(7, tanTypeConstraints)), }