From 81fbbcc552961d026105cc6d4c6bb8f102c4ac17 Mon Sep 17 00:00:00 2001 From: Bram Date: Sun, 22 Dec 2024 09:47:08 +0100 Subject: [PATCH] Add Erf operator (#224) --- ops/erf/erf.go | 77 +++++++++++++++++++++++++++++++++++++++++++++ ops/erf/erf_test.go | 51 ++++++++++++++++++++++++++++++ ops/erf/versions.go | 14 +++++++++ ops/types.go | 11 +++++++ ops_test.go | 1 + opset.go | 2 ++ 6 files changed, 156 insertions(+) create mode 100644 ops/erf/erf.go create mode 100644 ops/erf/erf_test.go create mode 100644 ops/erf/versions.go diff --git a/ops/erf/erf.go b/ops/erf/erf.go new file mode 100644 index 0000000..ca2708e --- /dev/null +++ b/ops/erf/erf.go @@ -0,0 +1,77 @@ +package erf + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var erfTypeConstraints = [][]tensor.Dtype{ops.NumericTypes} + +// Erf represents the ONNX erf operator. +type Erf struct { + ops.BaseOperator +} + +// newSin creates a new erf operator. +func newErf(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Erf{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "erf", + ), + } +} + +// Init initializes the erf operator. +func (e *Erf) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the erf operator. +func (e *Erf) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Uint8: + out, err = inputs[0].Apply(erf[uint8]) + case tensor.Uint16: + out, err = inputs[0].Apply(erf[uint16]) + case tensor.Uint32: + out, err = inputs[0].Apply(erf[uint32]) + case tensor.Uint64: + out, err = inputs[0].Apply(erf[uint64]) + case tensor.Int8: + out, err = inputs[0].Apply(erf[int8]) + case tensor.Int16: + out, err = inputs[0].Apply(erf[int16]) + case tensor.Int32: + out, err = inputs[0].Apply(erf[int32]) + case tensor.Int64: + out, err = inputs[0].Apply(erf[int64]) + case tensor.Float32: + out, err = inputs[0].Apply(erf[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(erf[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), e.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func erf[T ops.NumericType](x T) T { + return T(math.Erf(float64(x))) +} diff --git a/ops/erf/erf_test.go b/ops/erf/erf_test.go new file mode 100644 index 0000000..07b8bb8 --- /dev/null +++ b/ops/erf/erf_test.go @@ -0,0 +1,51 @@ +package erf + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestErfInit(t *testing.T) { + e := &Erf{} + err := e.Init(nil) + assert.Nil(t, err) +} + +func TestErf(t *testing.T) { + tests := []struct { + version int64 + backing []float32 + shape []int + expected []float32 + }{ + { + 9, + []float32{-1, -1, 0, 1}, + []int{2, 2}, + []float32{-0.8427008, -0.8427008, 0, 0.8427008}, + }, + { + 13, + []float32{1, 0.5, 0.0, -0.5}, + []int{1, 4}, + []float32{0.8427008, 0.5204999, 0, -0.5204999}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + erf := erfVersions[test.version]() + + res, err := erf.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} diff --git a/ops/erf/versions.go b/ops/erf/versions.go new file mode 100644 index 0000000..bb7abd4 --- /dev/null +++ b/ops/erf/versions.go @@ -0,0 +1,14 @@ +package erf + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var erfVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newErf, 9, erfTypeConstraints), + 13: ops.NewOperatorConstructor(newErf, 13, erfTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return erfVersions +} diff --git a/ops/types.go b/ops/types.go index edea5fb..fdc0f81 100644 --- a/ops/types.go +++ b/ops/types.go @@ -7,6 +7,10 @@ type FloatType interface { float32 | float64 } +type NumericType interface { + uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64 | FloatType +} + // AllTypes is a type constraint which allows all types. var AllTypes = []tensor.Dtype{ tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, @@ -16,3 +20,10 @@ var AllTypes = []tensor.Dtype{ tensor.String, tensor.Bool, } + +// NumericTypes is a list with all numeric types. +var NumericTypes = []tensor.Dtype{ + tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, + tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, + tensor.Float32, tensor.Float64, +} diff --git a/ops_test.go b/ops_test.go index 8d1ece4..8965ff2 100644 --- a/ops_test.go +++ b/ops_test.go @@ -407,6 +407,7 @@ var expectedTests = []string{ "test_div_example", "test_equal", "test_equal_bcast", + "test_erf", "test_expand_dim_changed", "test_expand_dim_unchanged", "test_flatten_axis0", diff --git a/opset.go b/opset.go index 2d35935..79437ad 100644 --- a/opset.go +++ b/opset.go @@ -21,6 +21,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/cosh" "github.com/advancedclimatesystems/gonnx/ops/div" "github.com/advancedclimatesystems/gonnx/ops/equal" + "github.com/advancedclimatesystems/gonnx/ops/erf" "github.com/advancedclimatesystems/gonnx/ops/expand" "github.com/advancedclimatesystems/gonnx/ops/flatten" "github.com/advancedclimatesystems/gonnx/ops/gather" @@ -88,6 +89,7 @@ var operators = map[string]ops.OperatorVersions{ "Cosh": cosh.GetCoshVersions(), "Div": div.GetDivVersions(), "Equal": equal.GetEqualVersions(), + "Erf": erf.GetVersions(), "Expand": expand.GetExpandVersions(), "Flatten": flatten.GetFlattenVersions(), "Gather": gather.GetGatherVersions(),