Skip to content

Commit

Permalink
Refactored Tan operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 4, 2024
1 parent d375d1f commit 555c3ac
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 89 deletions.
61 changes: 61 additions & 0 deletions ops/tan/tan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package tan

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var tanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}

// Tan represents the ONNX tan operator.
type Tan struct {
ops.BaseOperator
}

// newTan creates a new tan operator.
func newTan(version int, typeConstraints [][]tensor.Dtype) *Tan {
return &Tan{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraints,
"tan",
),
}
}

// Init initializes the tan operator.
func (t *Tan) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the tan operator.
func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(tan[float32])
case tensor.Float64:
out, err = inputs[0].Apply(tan[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t.BaseOperator)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

func tan[T ops.FloatType](x T) T {
return T(math.Tan(float64(x)))
}
75 changes: 0 additions & 75 deletions ops/tan/tan_7.go

This file was deleted.

35 changes: 22 additions & 13 deletions ops/tan/tan_7_test.go → ops/tan/tan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,36 @@ import (
"gorgonia.org/tensor"
)

func TestTan7Init(t *testing.T) {
a := &Tan7{}
func TestTanInit(t *testing.T) {
a := &Tan{}

// since 'tan' does not have any attributes we pass in nil. This should not
// fail initializing the tan.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestTan7(t *testing.T) {
func TestTan(t *testing.T) {
tests := []struct {
tan *Tan7
tan *Tan
backing []float32
shape []int
expected []float32
}{
{
&Tan7{},
&Tan{},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{1.5574077, -2.1850398, -0.14254655, 1.1578213},
},
{
&Tan7{},
&Tan{},
[]float32{1, 2, 3, 4},
[]int{1, 4},
[]float32{1.5574077, -2.1850398, -0.14254655, 1.1578213},
},
{
&Tan7{},
&Tan{},
[]float32{2, 2, 2, 2},
[]int{1, 4},
[]float32{-2.1850398, -2.1850398, -2.1850398, -2.1850398},
Expand All @@ -57,37 +57,42 @@ func TestTan7(t *testing.T) {
}
}

func TestInputValidationTan7(t *testing.T) {
func TestInputValidationTan(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
version int64
inputs []tensor.Tensor
err error
}{
{
7,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
7,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
7,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Tan7{}),
ops.ErrInvalidInputCount(0, tan7BaseOpFixture()),
},
{
7,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Tan7{}),
ops.ErrInvalidInputType(0, "int", tan7BaseOpFixture()),
},
}

for _, test := range tests {
tan := &Tan7{}
tan := TanVersions[test.version]()
validated, err := tan.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand All @@ -97,3 +102,7 @@ func TestInputValidationTan7(t *testing.T) {
}
}
}

func tan7BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(7, 1, 1, tanTypeConstraints, "tan")
}
2 changes: 1 addition & 1 deletion ops/tan/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package tan
import "github.com/advancedclimatesystems/gonnx/ops"

var TanVersions = ops.OperatorVersions{
7: newTan7,
7: ops.NewOperatorConstructor(newTan(7, tanTypeConstraints)),
}

0 comments on commit 555c3ac

Please sign in to comment.