Skip to content

Commit

Permalink
Refactored tanh operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 3, 2024
1 parent 67a7796 commit d375d1f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 122 deletions.
41 changes: 41 additions & 0 deletions ops/tanh/tanh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package tanh

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var tanhTypeConstraint = [][]tensor.Dtype{
{tensor.Float32, tensor.Float64},
}

// Tanh represents the tanh operator.
type Tanh struct {
ops.BaseOperator
}

// newTanh returns a new tanh operator.
func newTanh(version int, typeConstraint [][]tensor.Dtype) *Tanh {
return &Tanh{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraint,
"tanh",
),
}
}

// Init initializes the sigmoid operator.
func (t *Tanh) Init(*onnx.NodeProto) error {
return nil
}

// Apply the sigmoid operator to the input node.
func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := ops.Tanh(inputs[0])

return []tensor.Tensor{out}, err
}
54 changes: 0 additions & 54 deletions ops/tanh/tanh_13.go

This file was deleted.

54 changes: 0 additions & 54 deletions ops/tanh/tanh_6.go

This file was deleted.

48 changes: 38 additions & 10 deletions ops/tanh/tanh_13_test.go → ops/tanh/tanh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ import (
"gorgonia.org/tensor"
)

func TestTanh13Init(t *testing.T) {
tanh := newTanh13()
func TestTanhInit(t *testing.T) {
tanh := &Tanh{}
// Since the tanh does not have any attributes we expect it to initialize even
// when nil is passed.
err := tanh.Init(nil)

assert.Nil(t, err)
}

func TestTanh13(t *testing.T) {
func TestTanh(t *testing.T) {
tests := []struct {
backing []float32
shape []int
Expand All @@ -44,7 +44,7 @@ func TestTanh13(t *testing.T) {
}

for _, test := range tests {
tanh := &Tanh13{}
tanh := &Tanh{}
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}
Expand All @@ -55,31 +55,51 @@ func TestTanh13(t *testing.T) {
}
}

func TestInputValidationTanh13(t *testing.T) {
func TestInputValidationTanh(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
version int64
inputs []tensor.Tensor
err error
}{
{
6,
[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)},
nil,
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)},
nil,
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)},
nil,
},
{
6,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, tanh6BaseOpFixture()),
},
{
13,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Tanh13{}),
ops.ErrInvalidInputCount(0, tanh13BaseOpFixture()),
},
{
6,
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)},
ops.ErrInvalidInputType(0, "int", &Tanh13{}),
ops.ErrInvalidInputType(0, "int", tanh6BaseOpFixture()),
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)},
ops.ErrInvalidInputType(0, "int", tanh13BaseOpFixture()),
},
}

for _, test := range tests {
tanh := &Tanh13{}
tanh := TanhVersions[test.version]()
validated, err := tanh.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand All @@ -89,3 +109,11 @@ func TestInputValidationTanh13(t *testing.T) {
}
}
}

func tanh6BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(6, 1, 1, tanhTypeConstraint, "tanh")
}

func tanh13BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(13, 1, 1, tanhTypeConstraint, "tanh")
}
4 changes: 2 additions & 2 deletions ops/tanh/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ package tanh
import "github.com/advancedclimatesystems/gonnx/ops"

var TanhVersions = ops.OperatorVersions{
6: newTanh6, // Only bfloat16 type differs
13: newTanh13,
6: ops.NewOperatorConstructor(newTanh(6, tanhTypeConstraint)),
13: ops.NewOperatorConstructor(newTanh(13, tanhTypeConstraint)),
}
2 changes: 0 additions & 2 deletions ops/transpose/transpose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,8 @@ func TransposeOnnxNodeProtoFixture() *onnx.NodeProto {

func transpose1BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(1, 1, 1, transposeTypeConstraint, "transpose")

}

func transpose13BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(13, 1, 1, transposeTypeConstraint, "transpose")

}

0 comments on commit d375d1f

Please sign in to comment.