Skip to content

Commit

Permalink
Refactor acosh into base operator
Browse files Browse the repository at this point in the history
  • Loading branch information
wisse committed Dec 3, 2024
1 parent 27c0dfa commit cbd224d
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 92 deletions.
4 changes: 2 additions & 2 deletions ops/acos/acos.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ type Acos struct {
}

// newAcos creates a new acos operator.
func newAcos(verion int) ops.Operator {
func newAcos() ops.Operator {
return &Acos{
BaseOperator: ops.NewBaseOperator(
verion,
7,
1,
1,
[][]tensor.Dtype{{tensor.Float32, tensor.Float64}},
Expand Down
8 changes: 4 additions & 4 deletions ops/acos/acos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ func TestAcos(t *testing.T) {
expected []float32
}{
{
newAcos(7),
newAcos(),
[]float32{-1, -1, 0, 1},
[]int{2, 2},
[]float32{3.1415927, 3.1415927, 1.5707964, 0},
},
{
newAcos(7),
newAcos(),
[]float32{1, 0.5, 0.0, -0.5},
[]int{1, 4},
[]float32{0, 1.0471976, 1.5707964, 2.0943952},
},
{
newAcos(7),
newAcos(),
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{3.1415927, 3.1415927, 3.1415927, 3.1415927},
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestInputValidationAcos(t *testing.T) {
}

for _, test := range tests {
acos := newAcos(7)
acos := newAcos()
validated, err := acos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
2 changes: 1 addition & 1 deletion ops/acos/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ import (
)

var AcosVersions = ops.OperatorVersions{
7: ops.NewOperatorConstructor(newAcos(7)),
7: newAcos,
}
53 changes: 53 additions & 0 deletions ops/acosh/acosh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package acosh

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Acosh represents the ONNX acosh operator.
type Acosh struct {
ops.BaseOperator
}

// newAcosh creates a new acosh operator.
func newAcosh() ops.Operator {
return &Acosh{
BaseOperator: ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh"),
}
}

// Init initializes the acosh operator.
func (c *Acosh) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the acosh operator.
func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(acosh[float32])
case tensor.Float64:
out, err = inputs[0].Apply(acosh[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

func acosh[T ops.FloatType](x T) T {
return T(math.Acosh(float64(x)))
}
75 changes: 0 additions & 75 deletions ops/acosh/acosh_9.go

This file was deleted.

18 changes: 9 additions & 9 deletions ops/acosh/acosh_9_test.go → ops/acosh/acosh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestAcosh9Init(t *testing.T) {
c := &Acosh9{}
c := &Acosh{}

// since 'acosh' does not have any attributes we pass in nil. This should not
// fail initializing the acosh.
Expand All @@ -19,25 +19,25 @@ func TestAcosh9Init(t *testing.T) {

func TestAcosh9(t *testing.T) {
tests := []struct {
acosh *Acosh9
acosh ops.Operator
backing []float32
shape []int
expected []float32
}{
{
&Acosh9{},
newAcosh(),
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{0, 1.316958, 1.7627472, 2.063437},
},
{
&Acosh9{},
newAcosh(),
[]float32{1, 2, 3, 4},
[]int{1, 4},
[]float32{0, 1.316958, 1.7627472, 2.063437},
},
{
&Acosh9{},
newAcosh(),
[]float32{2, 2, 2, 2},
[]int{1, 4},
[]float32{1.316958, 1.316958, 1.316958, 1.316958},
Expand All @@ -57,7 +57,7 @@ func TestAcosh9(t *testing.T) {
}
}

func TestInputValidationAcosh9(t *testing.T) {
func TestInputValidationAcosh(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
Expand All @@ -76,18 +76,18 @@ func TestInputValidationAcosh9(t *testing.T) {
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Acosh9{}),
ops.ErrInvalidInputCount(0, ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Acosh9{}),
ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")),
},
}

for _, test := range tests {
acosh := &Acosh9{}
acosh := newAcosh()
validated, err := acosh.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
2 changes: 1 addition & 1 deletion ops/acosh/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ import (
)

var AcoshVersions = ops.OperatorVersions{
9: newAcosh9,
9: newAcosh,
}

0 comments on commit cbd224d

Please sign in to comment.