Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added not operator #170

Merged
merged 5 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions onnx/graph_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) {
values, err = getInt64Data(tp)
case typeMap["DOUBLE"]:
values, err = getDoubleData(tp)
case typeMap["BOOL"]:
values = getBoolData(tp)
default:
// At this moment the datatype is either UNDEFINED or some datatype we currently
// do not support.
Expand Down Expand Up @@ -291,8 +293,17 @@ func getDoubleData(tp *TensorProto) ([]float64, error) {
return ReadFloat64ArrayFromBytes(tp.RawData)
}

func getBoolData(tp *TensorProto) []bool {
if len(tp.Int32Data) > 0 {
return Int32ArrayToBoolArray(tp.GetInt32Data())
}

return ReadBoolArrayFromBytes(tp.RawData)
}

const (
float32Size int = 4
boolSize int = 1
uint8Size int = 1
int8Size int = 1
uint16Size int = 2
Expand Down Expand Up @@ -362,6 +373,16 @@ func ReadFloat64ArrayFromBytes(data []byte) ([]float64, error) {
return values, nil
}

// ReadBoolArrayFromBytes reads data and parses it to an array of bool.
func ReadBoolArrayFromBytes(data []byte) []bool {
values := make([]bool, len(data))
for i, b := range data {
values[i] = b > 0
}

return values
}

// ReadUint8ArrayFromBytes reads data and parses it to an array of uint8.
func ReadUint8ArrayFromBytes(data []byte) ([]uint8, error) {
buffer := bytes.NewReader(data)
Expand Down Expand Up @@ -586,6 +607,16 @@ func ReadInt64ArrayFromBytes(data []byte) ([]int64, error) {
return values, nil
}

// Int32ArrayToBoolArray converts an int32 array to a bool array.
func Int32ArrayToBoolArray(arr []int32) []bool {
newArr := make([]bool, len(arr))
for i, value := range arr {
newArr[i] = value == 1.0
}

return newArr
}

// Int32ArrayToInt8Array converts an int32 array to an int8 array.
func Int32ArrayToInt8Array(arr []int32) []int8 {
newArr := make([]int8, len(arr))
Expand Down
60 changes: 60 additions & 0 deletions ops/opset13/not.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package opset13

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Not represents the ONNX not operator.
type Not struct{}

// newNot creates a new not operator.
func newNot() ops.Operator {
return &Not{}
}

// Init initializes the not operator.
func (n *Not) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the not operator.
func (n *Not) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := inputs[0].Apply(not)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (n *Not) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(n, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (n *Not) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (n *Not) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (n *Not) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Bool}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (n *Not) String() string {
return "not operator"
}

func not(x bool) bool {
return !x
}
93 changes: 93 additions & 0 deletions ops/opset13/not_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestNotInit(t *testing.T) {
n := &Not{}

// since 'not' does not have any attributes we pass in nil. This should not
// fail initializing the not.
err := n.Init(nil)
assert.Nil(t, err)
}

func TestNot(t *testing.T) {
tests := []struct {
not *Not
backing []bool
shape []int
expected []bool
}{
{
&Not{},
[]bool{true, false, true, false},
[]int{2, 2},
[]bool{false, true, false, true},
},
{
&Not{},
[]bool{true, true, false, false},
[]int{1, 4},
[]bool{false, false, true, true},
},
{
&Not{},
[]bool{false, false, false, false},
[]int{4, 1},
[]bool{true, true, true, true},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.not.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationNot(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]bool{false, false}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Not{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Not{}),
},
}

for _, test := range tests {
not := &Not{}
validated, err := not.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var operators13 = map[string]func() ops.Operator{
"GRU": newGRU,
"MatMul": newMatMul,
"Mul": newMul,
"Not": newNot,
"PRelu": newPRelu,
"Relu": newRelu,
"Reshape": newReshape,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ func TestGetOperator(t *testing.T) {
newMul(),
nil,
},
{
"Not",
newNot(),
nil,
},
{
"Relu",
newRelu(),
Expand Down
1 change: 1 addition & 0 deletions ops/validate_inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error
func checkNInputs(op Operator, inputs []tensor.Tensor) (int, error) {
nInputs := len(inputs)
padLength := 0

min := op.GetMinInputs()
max := op.GetMaxInputs()

Expand Down
11 changes: 10 additions & 1 deletion ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/opset13"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"gorgonia.org/tensor"
)

// Currently we ignore some of tests provided by ONNX. This has to do with the
Expand Down Expand Up @@ -114,14 +115,19 @@ func TestOps(t *testing.T) {
assert.Nil(t, err)

for _, test := range tests {
fmt.Println(test.inputs)
t.Run(test.name, func(t *testing.T) {
outputs, err := test.model.Run(test.inputs)
assert.Nil(t, err)

for outputName := range test.outputs {
expectedTensor := test.outputs[outputName]
actualTensor := outputs[outputName]
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001)
if expectedTensor.Dtype() == tensor.Bool {
assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data())
} else {
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001)
}
}
})

Expand Down Expand Up @@ -330,6 +336,9 @@ var expectedTests = []string{
"test_mul",
"test_mul_bcast",
"test_mul_example",
"test_not_2d",
"test_not_3d",
"test_not_4d",
"test_prelu_broadcast",
"test_prelu_example",
"test_relu",
Expand Down