diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 7384a05..ead0e71 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -27,6 +27,7 @@ var operators13 = map[string]func() ops.Operator{ "Scaler": newScaler, "Shape": newShape, "Sigmoid": newSigmoid, + "Sin": newSin, "Slice": newSlice, "Squeeze": newSqueeze, "Sub": newSub, diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index 5a2fc35..3331de6 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -118,6 +118,11 @@ func TestGetOperator(t *testing.T) { newSigmoid(), nil, }, + { + "Sin", + newSin(), + nil, + }, { "Slice", newSlice(), diff --git a/ops/opset13/sin.go b/ops/opset13/sin.go new file mode 100644 index 0000000..50b371d --- /dev/null +++ b/ops/opset13/sin.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Sin represents the ONNX sin operator. +type Sin struct{} + +// newSin creates a new sin operator. +func newSin() ops.Operator { + return &Sin{} +} + +// Init initializes the sin operator. +func (s *Sin) Init(_ []*onnx.AttributeProto) error { + return nil +} + +// Apply applies the sin operator. +func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(sin[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(sin[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (s *Sin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(s, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (s *Sin) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (s *Sin) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (s *Sin) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (s *Sin) String() string { + return "sin operator" +} + +func sin[T ops.FloatType](x T) T { + return T(math.Sin(float64(x))) +} diff --git a/ops/opset13/sin_test.go b/ops/opset13/sin_test.go new file mode 100644 index 0000000..1ec4483 --- /dev/null +++ b/ops/opset13/sin_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestSinInit(t *testing.T) { + a := &Sin{} + + // since 'sin' does not have any attributes we pass in nil. This should not + // fail initializing the sin. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestSin(t *testing.T) { + tests := []struct { + sin *Sin + backing []float32 + shape []int + expected []float32 + }{ + { + &Sin{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{-0.9092974, -0.84147096, 0, 0.84147096}, + }, + { + &Sin{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{0.84147096, 0.14112, -0.7568025, -0.9589243}, + }, + { + &Sin{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{-0.84147096, -0.84147096, -0.84147096, -0.84147096}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.sin.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationSin(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Sin{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Sin{}), + }, + } + + for _, test := range tests { + sin := &Sin{} + validated, err := sin.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops_test.go b/ops_test.go index df7ea8e..2cb8218 100644 --- a/ops_test.go +++ b/ops_test.go @@ -335,6 +335,8 @@ var expectedTests = []string{ "test_reshape_zero_and_negative_dim", "test_reshape_zero_dim", "test_shape", + "test_sin", + "test_sin_example", "test_sigmoid_example", "test_sigmoid", "test_slice_negative_axes",