Skip to content

Commit

Permalink
Added Sin operator (#157)
Browse files Browse the repository at this point in the history
* Added Sin operator

* Added ONNX sin test coverage

* Fix tests

* Remove unused error

* Remove unused error

* Fix lint

* Fix lint

* Use float type

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Nov 26, 2023
1 parent 1e529ec commit dd3763b
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var operators13 = map[string]func() ops.Operator{
"Scaler": newScaler,
"Shape": newShape,
"Sigmoid": newSigmoid,
"Sin": newSin,
"Slice": newSlice,
"Squeeze": newSqueeze,
"Sub": newSub,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ func TestGetOperator(t *testing.T) {
newSigmoid(),
nil,
},
{
"Sin",
newSin(),
nil,
},
{
"Slice",
newSlice(),
Expand Down
75 changes: 75 additions & 0 deletions ops/opset13/sin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Sin represents the ONNX sin operator.
type Sin struct{}

// newSin creates a new sin operator.
func newSin() ops.Operator {
return &Sin{}
}

// Init initializes the sin operator.
func (s *Sin) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the sin operator.
func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(sin[float32])
case tensor.Float64:
out, err = inputs[0].Apply(sin[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (s *Sin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(s, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (s *Sin) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (s *Sin) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (s *Sin) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (s *Sin) String() string {
return "sin operator"
}

func sin[T ops.FloatType](x T) T {
return T(math.Sin(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/sin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestSinInit(t *testing.T) {
a := &Sin{}

// since 'sin' does not have any attributes we pass in nil. This should not
// fail initializing the sin.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestSin(t *testing.T) {
tests := []struct {
sin *Sin
backing []float32
shape []int
expected []float32
}{
{
&Sin{},
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-0.9092974, -0.84147096, 0, 0.84147096},
},
{
&Sin{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{0.84147096, 0.14112, -0.7568025, -0.9589243},
},
{
&Sin{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{-0.84147096, -0.84147096, -0.84147096, -0.84147096},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.sin.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationSin(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Sin{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Sin{}),
},
}

for _, test := range tests {
sin := &Sin{}
validated, err := sin.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ var expectedTests = []string{
"test_reshape_zero_and_negative_dim",
"test_reshape_zero_dim",
"test_shape",
"test_sin",
"test_sin_example",
"test_sigmoid_example",
"test_sigmoid",
"test_slice_negative_axes",
Expand Down

0 comments on commit dd3763b

Please sign in to comment.