Skip to content

Commit

Permalink
Added cos operator (#159)
Browse files Browse the repository at this point in the history
* Added cos operator

* Replace s struct identifier

* Missed characters

* Fixed comment

* Resolved MR comments

* Fix tests

* Fix lint

* remove unused errors

* Fix naming

* Group declarations

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Nov 13, 2023
1 parent 75cee12 commit 2d02c63
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 10 deletions.
75 changes: 75 additions & 0 deletions ops/opset13/cos.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Cos represents the ONNX cos operator.
type Cos struct{}

// newCos creates a new cos operator.
func newCos() ops.Operator {
return &Cos{}
}

// Init initializes the cos operator.
func (c *Cos) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the cos operator.
func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(cos[float32])
case tensor.Float64:
out, err = inputs[0].Apply(cos[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Cos) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Cos) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Cos) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Cos) String() string {
return "cos operator"
}

func cos[T ops.FloatType](x T) T {
return T(math.Cos(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/cos_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestCosInit(t *testing.T) {
c := &Cos{}

// since 'cos' does not have any attributes we pass in nil. This should not
// fail initializing the cos.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestCos(t *testing.T) {
tests := []struct {
cos *Cos
backing []float32
shape []int
expected []float32
}{
{
&Cos{},
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-0.41614684, 0.5403023, 1, 0.5403023},
},
{
&Cos{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{0.5403023, -0.9899925, -0.6536436, 0.2836622},
},
{
&Cos{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{0.5403023, 0.5403023, 0.5403023, 0.5403023},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.cos.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationCos(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Cos{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Cos{}),
},
}

for _, test := range tests {
cos := &Cos{}
validated, err := cos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ var operators13 = map[string]func() ops.Operator{
"Concat": newConcat,
"Constant": newConstant,
"ConstantOfShape": newConstantOfShape,
"Cos": newCos,
"Div": newDiv,
"Gather": newGather,
"Gemm": newGemm,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func TestGetOperator(t *testing.T) {
newConstantOfShape(),
nil,
},
{
"Cos",
newCos(),
nil,
},
{
"Div",
newDiv(),
Expand Down
17 changes: 17 additions & 0 deletions ops/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ops

import "gorgonia.org/tensor"

type FloatType interface {
float32 | float64
}

// AllTypes is a type constraint which allows all types.
var AllTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Float32, tensor.Float64,
tensor.Complex64, tensor.Complex128,
tensor.String,
tensor.Bool,
}
10 changes: 0 additions & 10 deletions ops/validate_inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ import (
"gorgonia.org/tensor"
)

// AllTypes is a type constraint which allows all types.
var AllTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Float32, tensor.Float64,
tensor.Complex64, tensor.Complex128,
tensor.String,
tensor.Bool,
}

// ValidateInputs validates if a list of nodes has enough (not too few or too many) nodes.
// When there are fewer input nodes then the given max, the list is padded with nils.
// Expects either 1 requirement ==> the expected number of inputs, or 2 requirements,
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ var expectedTests = []string{
"test_constant",
"test_constantofshape_float_ones",
"test_constantofshape_int_zeros",
"test_cos",
"test_cos_example",
"test_div",
"test_div_bcast",
"test_div_example",
Expand Down

0 comments on commit 2d02c63

Please sign in to comment.