Skip to content

Commit

Permalink
WIP on migrating operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 1, 2024
1 parent 8ccb457 commit de51332
Show file tree
Hide file tree
Showing 11 changed files with 449 additions and 90 deletions.
28 changes: 14 additions & 14 deletions ops/opset13/atanh.go → ops/atanh/atanh_9.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package atanh

import (
"math"
Expand All @@ -8,21 +8,21 @@ import (
"gorgonia.org/tensor"
)

// Atanh represents the ONNX atanh operator.
type Atanh struct{}
// Atanh9 represents the ONNX atanh operator.
type Atanh9 struct{}

// newAtanh creates a new atanh operator.
func newAtanh() ops.Operator {
return &Atanh{}
// newAtanh9 creates a new atanh operator.
func NewAtanh9() ops.Operator {
return &Atanh9{}
}

// Init initializes the atanh operator.
func (a *Atanh) Init(*onnx.NodeProto) error {
func (a *Atanh9) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the atanh operator.
func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (a *Atanh9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
Expand All @@ -45,29 +45,29 @@ func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (a *Atanh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (a *Atanh9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(a, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (a *Atanh) GetMinInputs() int {
func (a *Atanh9) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (a *Atanh) GetMaxInputs() int {
func (a *Atanh9) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (a *Atanh) GetInputTypeConstraints() [][]tensor.Dtype {
func (a *Atanh9) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (a *Atanh) String() string {
return "atanh operator"
func (a *Atanh9) String() string {
return "atanh9 operator"
}

func atanh[T ops.FloatType](x T) T {
Expand Down
24 changes: 12 additions & 12 deletions ops/opset13/atanh_test.go → ops/atanh/atanh_9_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package atanh

import (
"testing"
Expand All @@ -8,36 +8,36 @@ import (
"gorgonia.org/tensor"
)

func TestAtanhInit(t *testing.T) {
a := &Atanh{}
func TestAtanh9Init(t *testing.T) {
a := &Atanh9{}

// since 'atanh' does not have any attributes we pass in nil. This should not
// fail initializing the atanh.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestAtanh(t *testing.T) {
func TestAtanh9(t *testing.T) {
tests := []struct {
atanh *Atanh
atanh *Atanh9
backing []float32
shape []int
expected []float32
}{
{
&Atanh{},
&Atanh9{},
[]float32{-0.9, -0.5, 0, 0.5},
[]int{2, 2},
[]float32{-1.4722193, -0.54930615, 0, 0.54930615},
},
{
&Atanh{},
&Atanh9{},
[]float32{-0.9, -0.5, 0, 0.5},
[]int{1, 4},
[]float32{-1.4722193, -0.54930615, 0, 0.54930615},
},
{
&Atanh{},
&Atanh9{},
[]float32{0.5, 0.5, 0.5, 0.5},
[]int{1, 4},
[]float32{0.54930615, 0.54930615, 0.54930615, 0.54930615},
Expand All @@ -57,7 +57,7 @@ func TestAtanh(t *testing.T) {
}
}

func TestInputValidationAtanh(t *testing.T) {
func TestInputValidationAtanh9(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
Expand All @@ -76,18 +76,18 @@ func TestInputValidationAtanh(t *testing.T) {
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Atanh{}),
ops.ErrInvalidInputCount(0, &Atanh9{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Atanh{}),
ops.ErrInvalidInputType(0, "int", &Atanh9{}),
},
}

for _, test := range tests {
atanh := &Atanh{}
atanh := &Atanh9{}
validated, err := atanh.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
82 changes: 82 additions & 0 deletions ops/cast/cast_13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package cast

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

const (
MinCast13Inputs = 1
MaxCast13Inputs = 1
)

// Cast13 represents the ONNX cast operator.
type Cast13 struct {
to int32 // DataType to cast to, as defined by TensorProto
}

// newCast13 creates a new cast operator.
func NewCast13() ops.Operator {
return &Cast13{}
}

// Init initializes the cast operator.
func (c *Cast13) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()

if len(attributes) != 1 {
return ops.ErrInvalidAttributeCount(1, len(attributes), c)
}

attr := attributes[0]
if attr.GetName() == "to" {
c.to = int32(attr.GetI())
} else {
return ops.ErrInvalidAttribute(attr.GetName(), c)
}

return nil
}

// Apply applies the cast operator.
func (c *Cast13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := ops.ConvertTensorDtype(inputs[0], c.to)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Cast13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Cast13) GetMinInputs() int {
return MinCast13Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Cast13) GetMaxInputs() int {
return MaxCast13Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Cast13) GetInputTypeConstraints() [][]tensor.Dtype {
// tensor.String is specified by ONNX but not supported here yet.
return [][]tensor.Dtype{
{
tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64,
tensor.Float32, tensor.Float64,
},
}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Cast13) String() string {
return "cast13 operator"
}
28 changes: 14 additions & 14 deletions ops/opset13/cast_test.go → ops/cast/cast_13_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package cast

import (
"testing"
Expand All @@ -9,52 +9,52 @@ import (
"gorgonia.org/tensor"
)

func TestCastInit(t *testing.T) {
c := &Cast{}
func TestCast13Init(t *testing.T) {
c := &Cast13{}

err := c.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: 1}}})
assert.Nil(t, err)
assert.Equal(t, int32(1), c.to)
}

func TestCast(t *testing.T) {
func TestCast13(t *testing.T) {
tests := []struct {
cast *Cast
cast *Cast13
backing interface{}
shape []int
to int64
expected interface{}
}{
{
&Cast{},
&Cast13{},
[]float32{1.0, 1.0},
[]int{2},
11,
[]float64{1.0, 1.0},
},
{
&Cast{},
&Cast13{},
[]float32{1.3, 1.8},
[]int{2},
4,
[]uint16{1, 1},
},
{
&Cast{},
&Cast13{},
[]int8{1, 1},
[]int{2},
1,
[]float32{1.0, 1.0},
},
{
&Cast{},
&Cast13{},
[]int64{1, 1},
[]int{2},
11,
[]float64{1.0, 1.0},
},
{
&Cast{},
&Cast13{},
[]float64{1.4, 1.5},
[]int{2},
3,
Expand All @@ -74,7 +74,7 @@ func TestCast(t *testing.T) {
}
}

func TestInputValidationCast(t *testing.T) {
func TestInputValidationCast13(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
Expand All @@ -92,18 +92,18 @@ func TestInputValidationCast(t *testing.T) {
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
},
ops.ErrInvalidInputCount(2, &Cast{}),
ops.ErrInvalidInputCount(2, &Cast13{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]bool{true, false}, 2),
},
ops.ErrInvalidInputType(0, "bool", &Cast{}),
ops.ErrInvalidInputType(0, "bool", &Cast13{}),
},
}

for _, test := range tests {
cast := &Cast{}
cast := &Cast13{}
validated, err := cast.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
Loading

0 comments on commit de51332

Please sign in to comment.