Skip to content

Commit

Permalink
Add Erf operator (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 authored Dec 22, 2024
1 parent 7d3fe7a commit 81fbbcc
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 0 deletions.
77 changes: 77 additions & 0 deletions ops/erf/erf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package erf

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var erfTypeConstraints = [][]tensor.Dtype{ops.NumericTypes}

// Erf represents the ONNX erf operator.
type Erf struct {
ops.BaseOperator
}

// newSin creates a new erf operator.
func newErf(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Erf{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraints,
"erf",
),
}
}

// Init initializes the erf operator.
func (e *Erf) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the erf operator.
func (e *Erf) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Uint8:
out, err = inputs[0].Apply(erf[uint8])
case tensor.Uint16:
out, err = inputs[0].Apply(erf[uint16])
case tensor.Uint32:
out, err = inputs[0].Apply(erf[uint32])
case tensor.Uint64:
out, err = inputs[0].Apply(erf[uint64])
case tensor.Int8:
out, err = inputs[0].Apply(erf[int8])
case tensor.Int16:
out, err = inputs[0].Apply(erf[int16])
case tensor.Int32:
out, err = inputs[0].Apply(erf[int32])
case tensor.Int64:
out, err = inputs[0].Apply(erf[int64])
case tensor.Float32:
out, err = inputs[0].Apply(erf[float32])
case tensor.Float64:
out, err = inputs[0].Apply(erf[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), e.BaseOperator)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

func erf[T ops.NumericType](x T) T {
return T(math.Erf(float64(x)))
}
51 changes: 51 additions & 0 deletions ops/erf/erf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package erf

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestErfInit(t *testing.T) {
e := &Erf{}
err := e.Init(nil)
assert.Nil(t, err)
}

func TestErf(t *testing.T) {
tests := []struct {
version int64
backing []float32
shape []int
expected []float32
}{
{
9,
[]float32{-1, -1, 0, 1},
[]int{2, 2},
[]float32{-0.8427008, -0.8427008, 0, 0.8427008},
},
{
13,
[]float32{1, 0.5, 0.0, -0.5},
[]int{1, 4},
[]float32{0.8427008, 0.5204999, 0, -0.5204999},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

erf := erfVersions[test.version]()

res, err := erf.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}
14 changes: 14 additions & 0 deletions ops/erf/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package erf

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var erfVersions = ops.OperatorVersions{
9: ops.NewOperatorConstructor(newErf, 9, erfTypeConstraints),
13: ops.NewOperatorConstructor(newErf, 13, erfTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return erfVersions
}
11 changes: 11 additions & 0 deletions ops/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ type FloatType interface {
float32 | float64
}

type NumericType interface {
uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64 | FloatType
}

// AllTypes is a type constraint which allows all types.
var AllTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
Expand All @@ -16,3 +20,10 @@ var AllTypes = []tensor.Dtype{
tensor.String,
tensor.Bool,
}

// NumericTypes is a list with all numeric types.
var NumericTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Float32, tensor.Float64,
}
1 change: 1 addition & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ var expectedTests = []string{
"test_div_example",
"test_equal",
"test_equal_bcast",
"test_erf",
"test_expand_dim_changed",
"test_expand_dim_unchanged",
"test_flatten_axis0",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/cosh"
"github.com/advancedclimatesystems/gonnx/ops/div"
"github.com/advancedclimatesystems/gonnx/ops/equal"
"github.com/advancedclimatesystems/gonnx/ops/erf"
"github.com/advancedclimatesystems/gonnx/ops/expand"
"github.com/advancedclimatesystems/gonnx/ops/flatten"
"github.com/advancedclimatesystems/gonnx/ops/gather"
Expand Down Expand Up @@ -88,6 +89,7 @@ var operators = map[string]ops.OperatorVersions{
"Cosh": cosh.GetCoshVersions(),
"Div": div.GetDivVersions(),
"Equal": equal.GetEqualVersions(),
"Erf": erf.GetVersions(),
"Expand": expand.GetExpandVersions(),
"Flatten": flatten.GetFlattenVersions(),
"Gather": gather.GetGatherVersions(),
Expand Down

0 comments on commit 81fbbcc

Please sign in to comment.