Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Sin operator #157

Merged
merged 11 commits into from
Nov 26, 2023
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var operators13 = map[string]func() ops.Operator{
"Scaler": newScaler,
"Shape": newShape,
"Sigmoid": newSigmoid,
"Sin": newSin,
"Slice": newSlice,
"Squeeze": newSqueeze,
"Sub": newSub,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ func TestGetOperator(t *testing.T) {
newSigmoid(),
nil,
},
{
"Sin",
newSin(),
nil,
},
{
"Slice",
newSlice(),
Expand Down
75 changes: 75 additions & 0 deletions ops/opset13/sin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Sin represents the ONNX sin operator.
type Sin struct{}

// newSin creates a new sin operator.
func newSin() ops.Operator {
return &Sin{}
}

// Init initializes the sin operator.
func (s *Sin) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the sin operator.
func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(sin[float32])
case tensor.Float64:
out, err = inputs[0].Apply(sin[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (s *Sin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(s, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (s *Sin) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (s *Sin) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (s *Sin) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (s *Sin) String() string {
return "sin operator"
}

func sin[T ops.FloatType](x T) T {
return T(math.Sin(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/sin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestSinInit(t *testing.T) {
a := &Sin{}

// since 'sin' does not have any attributes we pass in nil. This should not
// fail initializing the sin.
err := a.Init(nil)
assert.Nil(t, err)
}

func TestSin(t *testing.T) {
tests := []struct {
sin *Sin
backing []float32
shape []int
expected []float32
}{
{
&Sin{},
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-0.9092974, -0.84147096, 0, 0.84147096},
},
{
&Sin{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{0.84147096, 0.14112, -0.7568025, -0.9589243},
},
{
&Sin{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{-0.84147096, -0.84147096, -0.84147096, -0.84147096},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.sin.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationSin(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Sin{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Sin{}),
},
}

for _, test := range tests {
sin := &Sin{}
validated, err := sin.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ var expectedTests = []string{
"test_reshape_zero_and_negative_dim",
"test_reshape_zero_dim",
"test_shape",
"test_sin",
"test_sin_example",
"test_sigmoid_example",
"test_sigmoid",
"test_slice_negative_axes",
Expand Down