Skip to content

Commit

Permalink
Merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Nov 24, 2023
2 parents 6252837 + 8431865 commit ceb1efa
Show file tree
Hide file tree
Showing 10 changed files with 1,478 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
test_data/
.coverage.out

sample_models/.env
75 changes: 75 additions & 0 deletions ops/opset13/acosh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Acosh represents the ONNX acosh operator.
type Acosh struct{}

// newAcosh creates a new acosh operator.
func newAcosh() ops.Operator {
return &Acosh{}
}

// Init initializes the acosh operator.
func (c *Acosh) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the acosh operator.
func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(acosh[float32])
case tensor.Float64:
out, err = inputs[0].Apply(acosh[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Acosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Acosh) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Acosh) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Acosh) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Acosh) String() string {
return "acosh operator"
}

func acosh[T ops.FloatType](x T) T {
return T(math.Acosh(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/acosh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestAcoshInit(t *testing.T) {
c := &Acosh{}

// since 'acosh' does not have any attributes we pass in nil. This should not
// fail initializing the acosh.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestAcosh(t *testing.T) {
tests := []struct {
acosh *Acosh
backing []float32
shape []int
expected []float32
}{
{
&Acosh{},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{0, 1.316958, 1.7627472, 2.063437},
},
{
&Acosh{},
[]float32{1, 2, 3, 4},
[]int{1, 4},
[]float32{0, 1.316958, 1.7627472, 2.063437},
},
{
&Acosh{},
[]float32{2, 2, 2, 2},
[]int{1, 4},
[]float32{1.316958, 1.316958, 1.316958, 1.316958},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.acosh.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationAcosh(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Acosh{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Acosh{}),
},
}

for _, test := range tests {
acosh := &Acosh{}
validated, err := acosh.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
Loading

0 comments on commit ceb1efa

Please sign in to comment.