Skip to content

Commit

Permalink
Finished all operator refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 2, 2024
1 parent 095692a commit 5a6800d
Show file tree
Hide file tree
Showing 129 changed files with 4,354 additions and 2,016 deletions.
22 changes: 11 additions & 11 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ type Tensors map[string]tensor.Tensor

// Model defines a model that can be used for inference.
type Model struct {
mp *onnx.ModelProto
parameters Tensors
GetOperator OpGetter
mp *onnx.ModelProto
parameters Tensors
Opset Opset
}

// NewModelFromFile creates a new model from a path to a file.
Expand Down Expand Up @@ -74,15 +74,15 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
}
}

GetOperator, err := ResolveOperatorGetter(opsetID)
opset, err := ResolveOpset(opsetID)
if err != nil {
return nil, err
}

return &Model{
mp: mp,
parameters: params,
GetOperator: GetOperator,
mp: mp,
parameters: params,
Opset: opset,
}, nil
}

Expand Down Expand Up @@ -167,12 +167,12 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) {
}

for _, n := range m.mp.Graph.GetNode() {
op, err := m.GetOperator(n.GetOpType())
if err != nil {
return nil, err
op, ok := m.Opset[n.GetOpType()]
if !ok {
return nil, ops.ErrUnknownOperatorType(n.GetOpType())
}

if err := m.applyOp(op, n, tensors); err != nil {
if err := m.applyOp(op(), n, tensors); err != nil {
return nil, err
}
}
Expand Down
6 changes: 3 additions & 3 deletions ops/constantofshape/constant_of_shape_9_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func TestIncorrectInput(t *testing.T) {
assert.NotNil(t, err)
assert.Equal(
t,
"constant of shape operator invalid tensor found, reason: expected tensor to have one element",
"constantofshape9 operator invalid tensor found, reason: expected tensor to have one element",
err.Error(),
)
}
Expand All @@ -154,7 +154,7 @@ func TestNegativeShapeNotAllowed(t *testing.T) {

assert.Equal(
t,
"constant of shape operator invalid tensor found, reason: empty dimensions are not allowed",
"constantofshape9 operator invalid tensor found, reason: empty dimensions are not allowed",
err.Error())
}

Expand All @@ -170,7 +170,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) {

assert.Equal(
t,
"constant of shape operator invalid tensor found, reason: empty dimensions are not allowed",
"constantofshape9 operator invalid tensor found, reason: empty dimensions are not allowed",
err.Error())
}

Expand Down
6 changes: 6 additions & 0 deletions ops/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ func ErrUnknownOperatorType(operatorType string) error {
return fmt.Errorf("%w: %s", ErrUnsupportedOperator, operatorType)
}

var ErrUnsupportedOperatorVersion = errors.New("unsupported opset operator version")

func ErrUnsupportedOperatorVersionType(opsetID int64, operatorType string) error {
return fmt.Errorf("%w: opset %d for operator %s", ErrUnsupportedOperator, opsetID, operatorType)
}

var ErrAxisNotInRange = errors.New("axis out of range")

func ErrNotAllAxesInRange(min, max int) error {
Expand Down
4 changes: 2 additions & 2 deletions ops/greater/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package greater
import "github.com/advancedclimatesystems/gonnx/ops"

var GreaterVersions = ops.OperatorVersions{
7: newGreater7,
9: newGreater9,
7: newGreater7, // Only float types
9: newGreater9, // bfloat16 added
13: newGreater13,
}
2 changes: 1 addition & 1 deletion ops/greaterorequal/greater_or_equal_12.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ func (g *GreaterOrEqual12) GetInputTypeConstraints() [][]tensor.Dtype {

// String implements the stringer interface, and can be used to format errors or messages.
func (g *GreaterOrEqual12) String() string {
return "greaterOrEqual12 operator"
return "greaterorequal12 operator"
}
61 changes: 61 additions & 0 deletions ops/less/less_13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package less

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var (
MinLess13Inputs = 2
MaxLess13Inputs = 2
)

// Less13 represents the ONNX less operator.
type Less13 struct{}

// newLess13 creates a new less operator.
func newLess13() ops.Operator {
return &Less13{}
}

// Init initializes the less operator.
func (l *Less13) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the less operator.
func (l *Less13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ApplyBinaryOperation(
inputs[0],
inputs[1],
ops.Lt,
ops.MultidirectionalBroadcasting,
)
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (l *Less13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(l, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (l *Less13) GetMinInputs() int {
return MinLess13Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (l *Less13) GetMaxInputs() int {
return MaxLess13Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (l *Less13) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (l *Less13) String() string {
return "less13 operator"
}
24 changes: 12 additions & 12 deletions ops/opset13/less_test.go → ops/less/less_13_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package less

import (
"testing"
Expand All @@ -8,36 +8,36 @@ import (
"gorgonia.org/tensor"
)

func TestLessInit(t *testing.T) {
l := &Less{}
func TestLess13Init(t *testing.T) {
l := &Less13{}

// since 'less' does not have any attributes we pass in nil. This should not
// fail initializing the less.
err := l.Init(ops.EmptyNodeProto())
assert.Nil(t, err)
}

func TestLess(t *testing.T) {
func TestLess13(t *testing.T) {
tests := []struct {
less *Less
less *Less13
backings [][]float32
shapes [][]int
expected []bool
}{
{
&Less{},
&Less13{},
[][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}},
[][]int{{2, 2}, {2, 2}},
[]bool{true, false, false, false},
},
{
&Less{},
&Less13{},
[][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}},
[][]int{{3, 2}, {3, 2}},
[]bool{true, true, false, false, false, false},
},
{
&Less{},
&Less13{},
[][]float32{{0, 1}, {0, 1, 2, 3}},
[][]int{{2}, {2, 2}},
[]bool{false, false, true, true},
Expand All @@ -58,7 +58,7 @@ func TestLess(t *testing.T) {
}
}

func TestInputValidationLess(t *testing.T) {
func TestInputValidationLess13(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
Expand Down Expand Up @@ -109,19 +109,19 @@ func TestInputValidationLess(t *testing.T) {
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputCount(1, &Less{}),
ops.ErrInvalidInputCount(1, &Less13{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
ops.TensorWithBackingFixture([]int{3, 4}, 2),
},
ops.ErrInvalidInputType(0, "int", &Less{}),
ops.ErrInvalidInputType(0, "int", &Less13{}),
},
}

for _, test := range tests {
less := &Less{}
less := &Less13{}
validated, err := less.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
61 changes: 61 additions & 0 deletions ops/less/less_7.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package less

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var (
MinLess7Inputs = 2
MaxLess7Inputs = 2
)

// Less7 represents the ONNX less operator.
type Less7 struct{}

// newLess7 creates a new less operator.
func newLess7() ops.Operator {
return &Less7{}
}

// Init initializes the less operator.
func (l *Less7) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the less operator.
func (l *Less7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ApplyBinaryOperation(
inputs[0],
inputs[1],
ops.Lt,
ops.MultidirectionalBroadcasting,
)
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (l *Less7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(l, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (l *Less7) GetMinInputs() int {
return MinLess7Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (l *Less7) GetMaxInputs() int {
return MaxLess7Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (l *Less7) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (l *Less7) String() string {
return "less7 operator"
}
36 changes: 18 additions & 18 deletions ops/opset13/less.go → ops/less/less_9.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package less

import (
"github.com/advancedclimatesystems/gonnx/onnx"
Expand All @@ -7,25 +7,25 @@ import (
)

var (
MinLessInputs = 2
MaxLessInputs = 2
MinLess9Inputs = 2
MaxLess9Inputs = 2
)

// Less represents the ONNX less operator.
type Less struct{}
// Less9 represents the ONNX less operator.
type Less9 struct{}

// newLess creates a new less operator.
func newLess() ops.Operator {
return &Less{}
// newLess9 creates a new less operator.
func newLess9() ops.Operator {
return &Less9{}
}

// Init initializes the less operator.
func (l *Less) Init(*onnx.NodeProto) error {
func (l *Less9) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the less operator.
func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (l *Less9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ApplyBinaryOperation(
inputs[0],
inputs[1],
Expand All @@ -35,27 +35,27 @@ func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (l *Less) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (l *Less9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(l, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (l *Less) GetMinInputs() int {
return MinLessInputs
func (l *Less9) GetMinInputs() int {
return MinLess9Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (l *Less) GetMaxInputs() int {
return MaxLessInputs
func (l *Less9) GetMaxInputs() int {
return MaxLess9Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (l *Less) GetInputTypeConstraints() [][]tensor.Dtype {
func (l *Less9) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (l *Less) String() string {
return "less operator"
func (l *Less9) String() string {
return "less9 operator"
}
9 changes: 9 additions & 0 deletions ops/less/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package less

import "github.com/advancedclimatesystems/gonnx/ops"

var LessVersions = ops.OperatorVersions{
7: newLess7, // Only float types
9: newLess9, // bfloat16 type
13: newLess13,
}
Loading

0 comments on commit 5a6800d

Please sign in to comment.