From 1d6aa996335ff5a46b3348d52a02fec2ea25000b Mon Sep 17 00:00:00 2001 From: wisse Date: Thu, 4 May 2023 21:17:14 +0200 Subject: [PATCH 01/15] Fix a lot of linter issues --- errors.go | 24 ++++++++++++++++++---- model.go | 35 ++++++++++++++++---------------- model_test.go | 5 ++++- onnx/graph_proto.go | 26 +++++++++++++++++------- ops/fixtures.go | 2 +- ops/opset13/gather.go | 2 +- ops/opset13/unsqueeze.go | 2 +- ops/unidir_broadcast.go | 1 - ops/utils.go | 43 +++++++++++++++++++++++++++++++--------- ops_test.go | 1 - opset.go | 3 ++- 11 files changed, 99 insertions(+), 45 deletions(-) diff --git a/errors.go b/errors.go index 82523e8..74cb976 100644 --- a/errors.go +++ b/errors.go @@ -1,7 +1,23 @@ package gonnx -// InvalidShapeError is used when the shape of an input tensor does not match the expectation. -const InvalidShapeError = "input shape does not match for %v: expected %v but got %v" +import ( + "errors" + "fmt" +) -// SetOutputTensorsError is used when the output of an operation could not be set. -const SetOutputTensorsError = "could not set output tensors, expected %v tensors but got %v" +var ( + errInvalidShape = errors.New("input shape does not match") + errSetOutputTensor = errors.New("could not set output tensor") + errModel = errors.New("gonnx model error") +) + +// TODO weird description. +func ErrInvalidShape(format string, a ...any) error { + return fmt.Errorf("%w: %s", errInvalidShape, fmt.Sprintf(format, a...)) +} + +// ErrModel is used for when an error ocured during setup of running onnx models. +// The user can specify a formatted message using the standard formatting rules. 
+func ErrModel(format string, a ...any) error { + return fmt.Errorf("%w: %s", errModel, fmt.Sprintf(format, a...)) +} diff --git a/model.go b/model.go index d1d8537..aa5779b 100644 --- a/model.go +++ b/model.go @@ -2,8 +2,8 @@ package gonnx import ( "archive/zip" - "fmt" - "io/ioutil" + "io" + "os" "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" @@ -11,7 +11,7 @@ import ( "gorgonia.org/tensor" ) -// Tensors is a map with tensors +// Tensors is a map with tensors. type Tensors map[string]tensor.Tensor // Model defines a model that can be used for inference. @@ -23,7 +23,7 @@ type Model struct { // NewModelFromFile creates a new model from a path to a file. func NewModelFromFile(path string) (*Model, error) { - bytesModel, err := ioutil.ReadFile(path) + bytesModel, err := os.ReadFile(path) if err != nil { return nil, err } @@ -38,7 +38,7 @@ func NewModelFromZipFile(file *zip.File) (*Model, error) { return nil, err } - bytesModel, err := ioutil.ReadAll(fc) + bytesModel, err := io.ReadAll(fc) if err != nil { return nil, err } @@ -66,6 +66,7 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) { opsetImports := mp.GetOpsetImport() var opsetID int64 + for i := 0; i < len(opsetImports); i++ { version := opsetImports[i].GetVersion() if version > opsetID { @@ -78,12 +79,11 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) { return nil, err } - model := &Model{ + return &Model{ mp: mp, parameters: params, GetOperator: GetOperator, - } - return model, nil + }, nil } // ModelProtoFromBytes creates an onnx.ModelProto based on a list of bytes. @@ -92,6 +92,7 @@ func ModelProtoFromBytes(bytesModel []byte) (*onnx.ModelProto, error) { if err := proto.Unmarshal(bytesModel, mp); err != nil { return nil, err } + return mp, nil } @@ -108,16 +109,13 @@ func (m *Model) InputShapes() onnx.Shapes { // InputDimSize returns the size of the input dimension given an input tensor. 
func (m *Model) InputDimSize(input string, i int) (int, error) { if !m.hasInput(input) { - return 0, fmt.Errorf("input %v does not exist", input) + return 0, ErrModel("input %v does not exist", input) } inputShape := m.mp.Graph.InputShapes()[input] if i >= len(inputShape) { - err := fmt.Errorf( - "input %v only has %d dimensions, but index %d was required", input, len(inputShape), i, - ) - return 0, err + return 0, ErrModel("input %v only has %d dimensions, but index %d was required", input, len(inputShape), i) } return int(inputShape[i].Size), nil @@ -222,13 +220,13 @@ func (m *Model) validateShapes(inputTensors Tensors) error { tensor, ok := inputTensors[name] if !ok { - return fmt.Errorf("tensor: %v not found", name) + return ErrModel("tensor: %v not found", name) } shapeReceived := tensor.Shape() if len(shapeReceived) != len(shapeExpected) { - return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived) + return ErrInvalidShape("shape does not match for %v: expected %v but got %v", name, shapeExpected, shapeReceived) } for i, dim := range shapeExpected { @@ -239,7 +237,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { } if dim.Size != int64(shapeReceived[i]) { - return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived) + return ErrInvalidShape("shape does not match for %v: expected %v but got %v", name, shapeExpected, shapeReceived) } } } @@ -249,6 +247,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, error) { var inputTensors []tensor.Tensor + for _, tensorName := range names { // An empty name can happen in between optional inputs, like: // [, , nil, ] @@ -259,7 +258,7 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e } else if tensor, ok := tensors[tensorName]; ok { inputTensors = append(inputTensors, tensor) } else { - return nil, fmt.Errorf("no tensor yet for name %v", tensorName) + 
return nil, ErrModel("no tensor yet for name %v", tensorName) } } @@ -270,7 +269,7 @@ func setOutputTensorsOfNode( names []string, outputTensors []tensor.Tensor, tensors Tensors, ) error { if len(names) != len(outputTensors) { - return fmt.Errorf(SetOutputTensorsError, len(names), len(outputTensors)) + return ErrModel("could not set output tensor") } for i, tensor := range outputTensors { diff --git a/model_test.go b/model_test.go index 87259ac..e0d15b0 100644 --- a/model_test.go +++ b/model_test.go @@ -169,7 +169,7 @@ func TestInputDimSizeInvalidInput(t *testing.T) { } // tensorsFixture creates Tensors with the given names shapes and backings. This is useful for -// providing a model with inputs and checking it's outputs +// providing a model with inputs and checking it's outputs. func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors { res := make(Tensors, len(names)) for i, name := range names { @@ -178,6 +178,7 @@ func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors tensor.WithBacking(backing[i]), ) } + return res } @@ -186,6 +187,7 @@ func rangeFloat(size int) []float32 { for i := 0; i < size; i++ { res[i] = float32(i) } + return res } @@ -194,6 +196,7 @@ func rangeZeros(size int) []float32 { for i := range res { res[i] = 0.0 } + return res } diff --git a/onnx/graph_proto.go b/onnx/graph_proto.go index 87decfa..44180c2 100644 --- a/onnx/graph_proto.go +++ b/onnx/graph_proto.go @@ -13,7 +13,7 @@ import ( "gorgonia.org/tensor" ) -// InputNames returns the input names for a GraphProto +// InputNames returns the input names for a GraphProto. func (g *GraphProto) InputNames() []string { return getNamesFromValueProto(g.GetInput()) } @@ -51,6 +51,7 @@ func (g *GraphProto) Params() (map[string]tensor.Tensor, error) { res[i.Name] = t } + return res, nil } @@ -62,10 +63,12 @@ type Shape []Dim // String prints a shape in a human-friendly matter. 
func (s Shape) String() string { - var dimSizes []int64 + dimSizes := make([]int64, 0, len(s)) + for _, dim := range s { dimSizes = append(dimSizes, dim.Size) } + return fmt.Sprintf("%d", dimSizes) } @@ -95,6 +98,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { if protos == nil { return map[string]Shape{} } + shapes := make(map[string]Shape, len(protos)) for _, p := range protos { @@ -119,6 +123,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { } shape := make([]Dim, len(dims)) + for i, dim := range dims { param := dim.GetDimParam() v := dim.GetDimValue() @@ -130,6 +135,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { shape[i] = Dim{IsDynamic: isDynamic, Name: param, Size: v} } + shapes[p.GetName()] = shape } @@ -146,7 +152,7 @@ func getNamesFromTensorProto(protos []*TensorProto) []string { return res } -// TensorFromProto returns a tensor.Tensor from an onnx.TensorProto +// TensorFromProto returns a tensor.Tensor from an onnx.TensorProto. 
func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { var values interface{} var err error @@ -297,11 +303,14 @@ func ReadFloat32ArrayFromBytes(data []byte) ([]float32, error) { buffer := bytes.NewReader(data) element := make([]byte, float32Size) - var err error - var values []float32 + var ( + err error + values []float32 + ) for { var n int + n, err = buffer.Read(element) if n != float32Size || err != nil { break @@ -323,11 +332,14 @@ func ReadFloat64ArrayFromBytes(data []byte) ([]float64, error) { buffer := bytes.NewReader(data) element := make([]byte, float64Size) - var err error - var values []float64 + var ( + err error + values []float64 + ) for { var n int + n, err = buffer.Read(element) if n != float64Size || err != nil { break diff --git a/ops/fixtures.go b/ops/fixtures.go index 12d552a..ba0d215 100644 --- a/ops/fixtures.go +++ b/ops/fixtures.go @@ -4,7 +4,7 @@ import ( "gorgonia.org/tensor" ) -// InputFixture is a function that generates inputs for ops. Useful in testing +// InputFixture is a function that generates inputs for ops. Useful in testing. type InputFixture func() []tensor.Tensor // Float32TensorFixture returns a float32 backed gorgonia node. It initializes all its values diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index ecea439..19f5199 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -192,7 +192,7 @@ func gather(out, data, indices tensor.Tensor, axis int) error { // Example: // > a = [-1, -2, -3] // > x = [1, 2, 3, 4, 5, 6, 7] -// insertWithReplace(a, x, 3) -> [1, 2, 3, -1, -2, -3, 5, 6, 7] +// insertWithReplace(a, x, 3) -> [1, 2, 3, -1, -2, -3, 5, 6, 7]. func insertWithReplace(a, x []int, axis int) []int { y := append([]int{}, x[:axis]...) y = append(y, a...) 
diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index 6648de3..51d3cd1 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -82,7 +82,7 @@ func (u *Unsqueeze) String() string { // Creates a new array, which is `original` with ones added at the indices specified by `indices` // `indices` may not contain duplicates, the elements are assumed to be in the range 0 <= x < N // and should be sorted in increasing order. -// Is done in a single pass through the new array with length: len(original) + len(indices) +// Is done in a single pass through the new array with length: len(original) + len(indices). func insertOnes(original, indices []int) []int { N := len(indices) + len(original) diff --git a/ops/unidir_broadcast.go b/ops/unidir_broadcast.go index 41b35d5..702404c 100644 --- a/ops/unidir_broadcast.go +++ b/ops/unidir_broadcast.go @@ -8,7 +8,6 @@ import ( // UnidirectionalBroadcast tries to broadcast tensor B to tensor A according to the ONNX standards. func UnidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { - reshapedB, err := reshapeTensorsForUnidirBroadcast(A, B) if err != nil { return nil, nil, fmt.Errorf(UnidirBroadcastErrTemplate, A.Shape(), B.Shape()) diff --git a/ops/utils.go b/ops/utils.go index 369aedf..096bc99 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -11,6 +11,7 @@ func Abs(x int) int { if x < 0 { x *= -1 } + return x } @@ -26,21 +27,26 @@ func AllInRange(arr []int, min, max int) bool { return false } } + return true } -// HasDuplicates checks if there are duplicates in the sorted array `arr` +// HasDuplicates checks if there are duplicates in the sorted array `arr`. 
func HasDuplicates(arr []int) bool { if len(arr) < 1 { return false } + prev := arr[0] + for _, x := range arr[1:] { if prev == x { return true } + prev = x } + return false } @@ -51,47 +57,58 @@ func OffsetArrayIfNegative(arr []int, offset int) { if ax < 0 { ax += offset } + arr[i] = ax } } // OffsetTensorIfNegative adds an offset to every negative element in tensor t. // Works only for tensors with Dtype int (same as offset). -func OffsetTensorIfNegative(t tensor.Tensor, offset int) { +func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { f := func(n int) int { if n < 0 { return n + offset } + return n } - t.Apply(f, tensor.WithReuse(t)) + + if _, err := t.Apply(f, tensor.WithReuse(t)); err != nil { + return err + } + + return nil } // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. -func AnyToIntSlice(any interface{}) ([]int, error) { +func AnyToIntSlice(value interface{}) ([]int, error) { var res []int - switch data := any.(type) { + switch data := value.(type) { case []int8: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int16: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int32: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int64: for _, value := range data { res = append(res, int(value)) } + return res, nil default: return nil, fmt.Errorf("could not cast %v to int list", data) @@ -122,8 +139,8 @@ func GetValueAsTensorType(value float64, dtype tensor.Dtype) (interface{}, error // IfScalarToSlice will wrap the value in a slice if it is a scalar in a slice with that value, // otherwise will return itself. 
-func IfScalarToSlice(any interface{}) interface{} { - switch data := any.(type) { +func IfScalarToSlice(value any) any { + switch data := value.(type) { case int8: return []int8{data} case int16: @@ -143,7 +160,7 @@ func IfScalarToSlice(any interface{}) interface{} { case complex128: return []complex128{data} default: - return any + return value } } @@ -153,6 +170,7 @@ func Zeros(size int) []float32 { for i := range res { res[i] = 0.0 } + return res } @@ -162,6 +180,7 @@ func Full(size int, value float32) []float32 { for i := range res { res[i] = value } + return res } @@ -171,6 +190,7 @@ func Ones(size int) []float32 { for i := range res { res[i] = 1.0 } + return res } @@ -180,6 +200,7 @@ func Arange(size int, step float32) []float32 { for i := range res { res[i] = float32(i) * step } + return res } @@ -193,12 +214,15 @@ func NElements(shp ...int) int { return nElem } -// PairwiseAssign essentially does pairwise t1 = t2 in place! +// PairwiseAssign essentially does pairwise t1 = t2 in place!. func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { if !t1.Shape().Eq(t2.Shape()) { return fmt.Errorf("Shapes of tensors must be equal, were %v and %v", t1.Shape(), t2.Shape()) } + it := t1.Iterator() + // We cannot check the error here since it is a post statement so ignore the nolint errcheck here. 
+ // nolint errcheck for it.Reset(); !it.Done(); it.Next() { coord := it.Coord() @@ -212,5 +236,6 @@ func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { return err } } + return nil } diff --git a/ops_test.go b/ops_test.go index 6fb94fd..2037d20 100644 --- a/ops_test.go +++ b/ops_test.go @@ -229,7 +229,6 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( tensors := make(Tensors) for i := 0; i < len(inputs); i++ { - filePath := fmt.Sprintf("%v/%v_%d.pb", basePath, baseFile, i) bytesInput, err := ioutil.ReadFile(filePath) if err != nil { diff --git a/opset.go b/opset.go index 908e8bd..0471864 100644 --- a/opset.go +++ b/opset.go @@ -20,9 +20,10 @@ func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { return GetOperator, nil } - var opsets []int64 + opsets := make([]int64, len(operatorGetters)) for version := range operatorGetters { opsets = append(opsets, version) } + return nil, fmt.Errorf("expected opset to be in %d, got operator set %d", opsets, opsetID) } From cb919a7727e1d232e6c956e7c51a1f241390ac04 Mon Sep 17 00:00:00 2001 From: wisse Date: Thu, 4 May 2023 21:22:01 +0200 Subject: [PATCH 02/15] Fix more lint --- ops/opset13/gru.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index a72107c..a642f5b 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -142,6 +142,7 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if err != nil { return nil, err } + return []tensor.Tensor{Y, Yh}, nil } @@ -266,6 +267,7 @@ func (g *GRU) getForwardWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err if err != nil { return nil, nil, nil, err } + return n[0], n[1], n[2], nil } @@ -275,6 +277,7 @@ func (g *GRU) getRecurrentWeights(R tensor.Tensor) (Rz, Rr, Rh tensor.Tensor, er if err != nil { return nil, nil, nil, err } + return recurrentWeights[0], recurrentWeights[1], recurrentWeights[2], nil } @@ -284,6 +287,7 @@ func (g *GRU) getBiases(B tensor.Tensor) 
(Wbz, Wbr, Wbh, Rbz, Rbr, Rbh tensor.Te if err != nil { return nil, nil, nil, nil, nil, nil, err } + return biases[0], biases[1], biases[2], biases[3], biases[4], biases[5], nil } @@ -308,6 +312,7 @@ func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { weights[i] = w } + return weights, nil } @@ -331,6 +336,7 @@ func (g *GRU) extractBiases(B tensor.Tensor) ([]tensor.Tensor, error) { biases[i] = w } + return biases, nil } From 2d920c62795c7e6430ad1ca05684d4b318e3fac1 Mon Sep 17 00:00:00 2001 From: wisse Date: Sun, 10 Sep 2023 11:20:38 +0200 Subject: [PATCH 03/15] Got to cast operator and disable `captLocal` --- .golangci.yml | 15 +++-- errors.go | 17 ++++- ops/errors.go | 135 ++++++++++++++++++++++++++++++++++++++ ops/multidir_broadcast.go | 31 ++++++--- ops/opset13/abs.go | 14 +++- ops/opset13/abs_test.go | 9 +-- ops/opset13/add.go | 14 +++- ops/opset13/add_test.go | 17 ++--- ops/opset13/cast.go | 6 +- ops/opset13/cast_test.go | 2 +- ops/opset13/mul.go | 2 +- ops/opset13/squeeze.go | 37 +++++++---- ops/validate_inputs.go | 18 +++-- 13 files changed, 251 insertions(+), 66 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 6fb336b..e71a88c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -43,10 +43,17 @@ linters: linters-settings: gomnd: ignored-functions: - - 'strconv.ParseInt' - - 'strconv.ParseFloat' - - 'strconv.FormatInt' - - 'strconv.FormatFloat' + - "strconv.ParseInt" + - "strconv.ParseFloat" + - "strconv.FormatInt" + - "strconv.FormatFloat" + gocritic: + disabled-checks: + # In the world of AI tensor's are often denoted with a capital letter. + # We want to adopt the go style guide as much as possible but we also want + # to be able to easily show when a variable is a Tensor. So we chose to + # disable captLocal. Note that any other parameter should use a lower case letters. 
+ - "captLocal" issues: max-issues-per-linter: 0 max-same-issues: 0 diff --git a/errors.go b/errors.go index 74cb976..a032b6c 100644 --- a/errors.go +++ b/errors.go @@ -11,9 +11,20 @@ var ( errModel = errors.New("gonnx model error") ) -// TODO weird description. -func ErrInvalidShape(format string, a ...any) error { - return fmt.Errorf("%w: %s", errInvalidShape, fmt.Sprintf(format, a...)) +type InvalidShapeError struct { + expected []int + actual []int +} + +func (i InvalidShapeError) Error() string { + return fmt.Sprintf("invalid shape error expected: %v actual %v. mehtod %s", i.expected, i.actual) +} + +func ErrInvalidShape(expected, actual []int) error { + return InvalidShapeError{ + expected: expected, + actual: actual, + } } // ErrModel is used for when an error ocured during setup of running onnx models. diff --git a/ops/errors.go b/ops/errors.go index ab25387..b91d187 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -1,5 +1,53 @@ package ops +import ( + "errors" + "fmt" + "reflect" + + "gorgonia.org/tensor" +) + +type AttributeError struct { + kind string + attributeCount int + expectedCount int + attributeName string + operator Operator +} + +func (t *AttributeError) Error() string { + switch t.kind { + case "count": + return fmt.Sprintf("operator %s attribute error: invalid count %d expected %d", t.operator.String(), t.attributeCount, t.expectedCount) + case "name": + return fmt.Sprintf("operator %s attribute error: invalid attribute %s", t.operator.String(), t.attributeName) + } + + return fmt.Sprintf("attribute error") +} + +func ErrInvalidAttribute(attributeName string, operator Operator) error { + return &AttributeError{attributeName: attributeName, kind: "count", operator: operator} +} + +func ErrInvalidAttributeCount(expected, actual int, operator Operator) error { + return &AttributeError{attributeCount: actual, expectedCount: expected, kind: "count", operator: operator} +} + +type TypeAssertError struct { + expectedType string + actualType any +} 
+ +func (t *TypeAssertError) Error() string { + return fmt.Sprintf("type assert error: expected %v, got %v", t.expectedType, reflect.TypeOf(t.actualType)) +} + +func ErrTypeAssert(expected string, actual any) error { + return &TypeAssertError{expectedType: expected, actualType: actual} +} + // UnknownAttributeErrTemplate is used to format an error // when an operator finds an unknown attribute during its initialization. const UnknownAttributeErrTemplate = "%v: unknown attribute: %v" @@ -20,6 +68,83 @@ const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d" // the wrong amount of input tensors when optional inputs are present. const InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d" +type InvalidInputTypeError struct { + inputNumber int + actualType string + operator Operator +} + +func (i *InvalidInputTypeError) Error() string { + return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) +} + +func ErrInvalidInputType(operator Operator, inputNumber int, dType string) error { + return &InvalidInputTypeError{ + operator: operator, + inputNumber: inputNumber, + actualType: dType, + } +} + +type InvalidInputCountError struct { + hasOptionalInputs bool + actualCount int + operator Operator +} + +func (i *InvalidInputCountError) Error() string { + if i.hasOptionalInputs { + return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount) + } + + return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) +} + +func ErrInvalidInputCount(operator Operator, actual int) error { + return &InvalidInputCountError{ + actualCount: actual, + operator: operator, + } +} + +func ErrInvalidOptionalInputCount(operator Operator, actual int) error { + return &InvalidInputCountError{ + hasOptionalInputs: true, + actualCount: actual, + operator: operator, + } +} + +type 
BroadcastError struct { + broadcastType string + shapeA tensor.Shape + shapeB tensor.Shape + err error +} + +func (b *BroadcastError) Error() string { + return fmt.Sprintf("%v: could not perform %v, inputs with shape %d and %d.", b.err, b.broadcastType, b.shapeA, b.shapeB) +} + +func ErrMultidirBroadcast(shapeA, shapeB tensor.Shape, err error) error { + return &BroadcastError{ + broadcastType: "multidirectional broadcast", + shapeA: shapeA, + shapeB: shapeB, + err: err, + } +} + +func ErrUnidirBroadcast(shapeA, shapeB tensor.Shape) error { + return &BroadcastError{ + broadcastType: "Unidirectional broadcast", + shapeA: shapeA, + shapeB: shapeB, + } +} + +var ErrIncompatibleDimension = errors.New("incompatible dimensions") + // UnknowOpTypeErrTemplate is used to format an error when the operator type is unknown. const UnknowOpTypeErrTemplate = "unknown operator type: %v" @@ -38,3 +163,13 @@ const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x < // AxesNotAllInRangeErrTemplate is used to format an error when not all indices // are within a given range. 
const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d" + +var ErrAxisNotInRange = errors.New("axis out of range") + +func ErrNotAllAxisInRange(min, max int) error { + return fmt.Errorf("%w: all indices entries must be in the range -%d <= x < %d", ErrAxisNotInRange, min, max) +} + +func ErrAxisOutOfRange(min, max, actual int) error { + return fmt.Errorf("%w: axis argument must be in the range -%d <= x < %d, was %d", ErrAxisNotInRange, min, max, actual) +} diff --git a/ops/multidir_broadcast.go b/ops/multidir_broadcast.go index 9b81961..a55a361 100644 --- a/ops/multidir_broadcast.go +++ b/ops/multidir_broadcast.go @@ -1,8 +1,6 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) @@ -11,12 +9,12 @@ import ( func MultidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { newA, newB, err := ReshapeTensorsForMultidirBroadcast(A, B) if err != nil { - return nil, nil, fmt.Errorf(MultidirBroadcastErrTemplate, A.Shape(), B.Shape(), err) + return nil, nil, ErrMultidirBroadcast(A.Shape(), B.Shape(), err) } newA, newB, err = repeatTensorsForMutltidirBroadcast(newA, newB) if err != nil { - return nil, nil, fmt.Errorf(MultidirBroadcastErrTemplate, A.Shape(), B.Shape(), err) + return nil, nil, ErrMultidirBroadcast(A.Shape(), B.Shape(), err) } return newA, newB, nil @@ -38,12 +36,14 @@ func ReshapeTensorsForMultidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens if err != nil { return nil, nil, err } + return A, newB, nil case nDimsB > nDimsA: newA, err := addExtraDimsToTensor(A, nDimsB-nDimsA) if err != nil { return nil, nil, err } + return newA, B, nil default: return A, B, nil @@ -55,9 +55,11 @@ func ReshapeTensorsForMultidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens // the dimension of the other. If both sizes are not 1, the tensors cannot be broadcasted to // each other. It is assumed that both tensors are reshaped accordingly first. 
// Example: -// shapeA=(1, 3, 4) and shapeB=(2, 3, 1) yields shapeNewA=(2, 3, 4) and shapeNewB=(2, 3, 4). +// +// shapeA=(1, 3, 4) and shapeB=(2, 3, 1) yields shapeNewA=(2, 3, 4) and shapeNewB=(2, 3, 4). func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { var err error + shapeA := A.Shape() shapeB := B.Shape() nDims := len(shapeA) @@ -73,13 +75,15 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens if err != nil { return nil, nil, err } + case sizeDimB == 1: B, err = tensor.Repeat(B, axis, sizeDimA) if err != nil { return nil, nil, err } + default: - return nil, nil, fmt.Errorf("incompatible dimensions") + return nil, nil, ErrIncompatibleDimension } } } @@ -92,14 +96,21 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens // The given tensor is cloned such that the tensor is not modified in place. // Example: if we add 2 extra dimensions to shape (2, 3) we get shape (1, 1, 2, 3). func addExtraDimsToTensor(t tensor.Tensor, nExtraDims int) (tensor.Tensor, error) { - t = t.Clone().(tensor.Tensor) + t, ok := t.Clone().(tensor.Tensor) + if !ok { + return nil, ErrTypeAssert("tensor.Tensor", t.Clone()) + } - var newShape []int + newShape := []int{} for i := 0; i < nExtraDims; i++ { newShape = append(newShape, 1) } + newShape = append(newShape, t.Shape()...) - err := t.Reshape(newShape...) - return t, err + if err := t.Reshape(newShape...); err != nil { + return nil, err + } + + return t, nil } diff --git a/ops/opset13/abs.go b/ops/opset13/abs.go index 5c27994..15f8afc 100644 --- a/ops/opset13/abs.go +++ b/ops/opset13/abs.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinAbsInput is the minimimum amount of inputs the abs operator expects. + MinAbsInput = 1 + + // MaxAbsInput is the maximum amount of inputs the abs operator accepts. + MaxAbsInput = 1 +) + // Abs represents the ONNX abs operator. 
type Abs struct{} @@ -15,7 +23,7 @@ func newAbs() ops.Operator { } // Init initializes the abs operator. -func (a *Abs) Init(attributes []*onnx.AttributeProto) error { +func (a *Abs) Init([]*onnx.AttributeProto) error { return nil } @@ -36,12 +44,12 @@ func (a *Abs) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (a *Abs) GetMinInputs() int { - return 1 + return MinAbsInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (a *Abs) GetMaxInputs() int { - return 1 + return MaxAbsInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/abs_test.go b/ops/opset13/abs_test.go index f31e675..6589c37 100644 --- a/ops/opset13/abs_test.go +++ b/ops/opset13/abs_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -125,13 +124,13 @@ func TestInputValidationAbs(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("abs operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(&Abs{}, 0), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("abs operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(&Abs{}, 0, "int"), }, } @@ -140,8 +139,6 @@ func TestInputValidationAbs(t *testing.T) { validated, err := abs.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) - if test.err == nil { - assert.Equal(t, test.inputs, validated) - } + assert.Equal(t, test.inputs, validated) } } diff --git a/ops/opset13/add.go b/ops/opset13/add.go index a80422f..32a5ded 100644 --- a/ops/opset13/add.go +++ b/ops/opset13/add.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinAddInput is the minimimum amount of inputs the add operator expects. + MinAddInput = 1 + + // MaxAddInput is the maximum amount of inputs the add operator accepts. 
+ MaxAddInput = 1 +) + // Add represents the ONNX add operator. type Add struct{} @@ -15,7 +23,7 @@ func newAdd() ops.Operator { } // Init initializes the add operator. -func (a *Add) Init(attributes []*onnx.AttributeProto) error { +func (a *Add) Init(_ []*onnx.AttributeProto) error { return nil } @@ -41,12 +49,12 @@ func (a *Add) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (a *Add) GetMinInputs() int { - return 2 + return MinAddInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (a *Add) GetMaxInputs() int { - return 2 + return MaxAddInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/add_test.go b/ops/opset13/add_test.go index aa83c48..d69df55 100644 --- a/ops/opset13/add_test.go +++ b/ops/opset13/add_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -67,16 +66,7 @@ func TestAddFail(t *testing.T) { add := &Add{} _, err := add.Apply(inputs) - assert.Equal( - t, - err, - fmt.Errorf( - ops.MultidirBroadcastErrTemplate, - []int{2, 2}, - []int{3}, - "incompatible dimensions", - ), - ) + assert.Equal(t, err, ops.ErrMultidirBroadcast(inputs[0].Shape(), inputs[1].Shape(), ops.ErrIncompatibleDimension)) } func TestInputValidationAdd(t *testing.T) { @@ -130,14 +120,14 @@ func TestInputValidationAdd(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("add operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(&Add{}, 1), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("add operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(&Add{}, 0, "int"), }, } @@ -146,6 +136,7 @@ func TestInputValidationAdd(t *testing.T) { validated, 
err := add.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index 9f6e843..f039710 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -21,14 +19,14 @@ func newCast() ops.Operator { // Init initializes the cast operator. func (c *Cast) Init(attributes []*onnx.AttributeProto) error { if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } attr := attributes[0] if attr.GetName() == "to" { c.to = int32(attr.GetI()) } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, c, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), c) } return nil diff --git a/ops/opset13/cast_test.go b/ops/opset13/cast_test.go index 51d8942..593efe4 100644 --- a/ops/opset13/cast_test.go +++ b/ops/opset13/cast_test.go @@ -64,7 +64,7 @@ func TestCast(t *testing.T) { } for _, test := range tests { - test.cast.Init([]*onnx.AttributeProto{{Name: "to", I: test.to}}) + _ = test.cast.Init([]*onnx.AttributeProto{{Name: "to", I: test.to}}) inputs := []tensor.Tensor{ops.TensorWithBackingFixture(test.backing, test.shape...)} res, err := test.cast.Apply(inputs) diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go index ed5a240..08b3e42 100644 --- a/ops/opset13/mul.go +++ b/ops/opset13/mul.go @@ -15,7 +15,7 @@ func newMul() ops.Operator { } // Init initializes the mul operator. 
-func (m *Mul) Init(attributes []*onnx.AttributeProto) error { +func (m *Mul) Init(_ []*onnx.AttributeProto) error { return nil } diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index 5cdf64c..9ee81cc 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + SqueezeMinInput = 2 + SqueezeMaxInput = 2 +) + // Squeeze represents the ONNX squeeze operator. type Squeeze struct{} @@ -17,19 +20,20 @@ func newSqueeze() ops.Operator { } // Init initializes the squeeze operator. -func (s *Squeeze) Init(attributes []*onnx.AttributeProto) error { +func (s *Squeeze) Init(_ []*onnx.AttributeProto) error { return nil } // Apply applies the squeeze operator. func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var err error + currentShape := inputs[0].Shape() nDims := len(currentShape) dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) if !ops.AllInRange(dimsToSqueeze, -nDims, nDims-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, nDims, nDims) + return nil, ops.ErrNotAllAxisInRange(nDims, nDims) } // negative entries should be offset by the rank of the output tensor @@ -45,8 +49,13 @@ func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { newShape := getNewShape(currentShape, dimsToSqueeze) - out := inputs[0].Clone().(tensor.Tensor) + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, nil + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -57,12 +66,12 @@ func (s *Squeeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Squeeze) GetMinInputs() int { - return 1 + return SqueezeMinInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (s *Squeeze) GetMaxInputs() int { - return 2 + return SqueezeMaxInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -91,29 +100,34 @@ func getDimsToSqueezeFromTensor(t tensor.Tensor, nDims int) ([]int, error) { dimsToSqueeze[i] = nDims + val } } + return dimsToSqueeze, nil } // getDimsToSqueezeFromShape creates a list with ints representing the dimensions/axes to squeeze // based on the current shape. All dimensions with only 1 value will be squeezed. func getDimsToSqueezeFromShape(shape []int) []int { - var res []int + result := []int{} + for i, size := range shape { if size == 1 { - res = append(res, i) + result = append(result, i) } } - return res + + return result } // getNewShape returns a new shape based on the current shape and a list of dims to squeeze. func getNewShape(currentShape tensor.Shape, dimsToSqueeze []int) []int { - var newShape []int + newShape := []int{} + for i, dimSize := range currentShape { if keepDim(i, dimsToSqueeze) { newShape = append(newShape, dimSize) } } + return newShape } @@ -124,5 +138,6 @@ func keepDim(dim int, dimsToSqueeze []int) bool { return false } } + return true } diff --git a/ops/validate_inputs.go b/ops/validate_inputs.go index 8b954cb..6995e81 100644 --- a/ops/validate_inputs.go +++ b/ops/validate_inputs.go @@ -1,8 +1,6 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) @@ -38,19 +36,21 @@ func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error func checkNInputs(op Operator, inputs []tensor.Tensor) (int, error) { nInputs := len(inputs) - var padLength int - + padLength := 0 min := op.GetMinInputs() max := op.GetMaxInputs() + if min == max { if nInputs != min { - return 0, fmt.Errorf(InvalidInputCountErrTemplate, op, min, nInputs) + return 0, ErrInvalidInputCount(op, nInputs) } + padLength = min } else { if nInputs < min || nInputs > max { - return 0, fmt.Errorf(InvalidOptionalInputCountErrTemplate, op, min, max, nInputs) + 
return 0, ErrInvalidOptionalInputCount(op, nInputs) } + padLength = max } @@ -62,11 +62,13 @@ func padInputs(inputs []tensor.Tensor, length int) []tensor.Tensor { for len(inputs) < length { inputs = append(inputs, nil) } + return inputs } func checkInputTypes(op Operator, inputs []tensor.Tensor) error { typeConstraints := op.GetInputTypeConstraints() + for i, input := range inputs { // Optional inputs can be nil, we can not check for type constraints then. if input == nil { @@ -76,9 +78,10 @@ func checkInputTypes(op Operator, inputs []tensor.Tensor) error { typeConstraint := newTypeConstraint(typeConstraints[i]) if _, ok := typeConstraint[input.Dtype()]; !ok { - return fmt.Errorf("%v: input %d does not allow type %v", op, i, input.Dtype()) + return ErrInvalidInputType(op, i, input.Dtype().Name()) } } + return nil } @@ -89,5 +92,6 @@ func newTypeConstraint(allowedTypes []tensor.Dtype) map[tensor.Dtype]bool { for _, allowedType := range allowedTypes { typeConstraint[allowedType] = true } + return typeConstraint } From ba704a68819830bebafa83bdc93ce6a86a86cff2 Mon Sep 17 00:00:00 2001 From: wisse Date: Wed, 13 Sep 2023 15:06:05 +0200 Subject: [PATCH 04/15] WIP: more lint fixes --- ops/errors.go | 51 +++++++++++++++++------ ops/opset13/add_test.go | 4 +- ops/opset13/cast.go | 12 +++++- ops/opset13/cast_test.go | 6 +-- ops/opset13/concat.go | 13 ++++-- ops/opset13/concat_test.go | 3 +- ops/opset13/constant.go | 20 ++++++--- ops/opset13/constant_of_shape.go | 60 ++++++++++++++------------- ops/opset13/constant_of_shape_test.go | 13 +++--- ops/opset13/constant_test.go | 1 + 10 files changed, 118 insertions(+), 65 deletions(-) diff --git a/ops/errors.go b/ops/errors.go index b91d187..6259740 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -8,8 +8,16 @@ import ( "gorgonia.org/tensor" ) +type AttributeErrorKind string + +const ( + AttributeErrorCount AttributeErrorKind = "count" + AttributeErrorInvalid AttributeErrorKind = "invalid" + AttributeErrorUnsupported 
AttributeErrorKind = "unsupported" +) + type AttributeError struct { - kind string + kind AttributeErrorKind attributeCount int expectedCount int attributeName string @@ -18,23 +26,29 @@ type AttributeError struct { func (t *AttributeError) Error() string { switch t.kind { - case "count": + case AttributeErrorCount: return fmt.Sprintf("operator %s attribute error: invalid count %d expected %d", t.operator.String(), t.attributeCount, t.expectedCount) - case "name": + case AttributeErrorInvalid: return fmt.Sprintf("operator %s attribute error: invalid attribute %s", t.operator.String(), t.attributeName) + case AttributeErrorUnsupported: + return fmt.Sprintf("operator %s attribute error: unsupported attribute %s", t.operator.String(), t.attributeName) + default: + return fmt.Sprintf("operator %s unknown error attribute error kind %s", t.operator.String(), t.kind) } - - return fmt.Sprintf("attribute error") } -func ErrInvalidAttribute(attributeName string, operator Operator) error { - return &AttributeError{attributeName: attributeName, kind: "count", operator: operator} +func ErrInvalidAttribute(attributeName string, operator Operator) *AttributeError { + return &AttributeError{attributeName: attributeName, kind: "invalid", operator: operator} } func ErrInvalidAttributeCount(expected, actual int, operator Operator) error { return &AttributeError{attributeCount: actual, expectedCount: expected, kind: "count", operator: operator} } +func ErrUnsupportedAttribute(attributeName string, operator Operator) error { + return &AttributeError{attributeName: attributeName, kind: "unsupported", operator: operator} +} + type TypeAssertError struct { expectedType string actualType any @@ -68,18 +82,18 @@ const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d" // the wrong amount of input tensors when optional inputs are present. 
const InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d" -type InvalidInputTypeError struct { +type InvalidInputError struct { inputNumber int actualType string operator Operator } -func (i *InvalidInputTypeError) Error() string { +func (i *InvalidInputError) Error() string { return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) } -func ErrInvalidInputType(operator Operator, inputNumber int, dType string) error { - return &InvalidInputTypeError{ +func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error { + return &InvalidInputError{ operator: operator, inputNumber: inputNumber, actualType: dType, @@ -100,7 +114,7 @@ func (i *InvalidInputCountError) Error() string { return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) } -func ErrInvalidInputCount(operator Operator, actual int) error { +func ErrInvalidInputCount(actual int, operator Operator) error { return &InvalidInputCountError{ actualCount: actual, operator: operator, @@ -143,6 +157,19 @@ func ErrUnidirBroadcast(shapeA, shapeB tensor.Shape) error { } } +type InvalidTensorError struct { + reason string + operator Operator +} + +func (i *InvalidTensorError) Error() string { + return fmt.Sprintf("%v invalid tensor found, reason: %s", i.operator.String(), i.reason) +} + +func ErrInvalidTensor(reason string, operator Operator) error { + return &InvalidTensorError{reason: reason, operator: operator} +} + var ErrIncompatibleDimension = errors.New("incompatible dimensions") // UnknowOpTypeErrTemplate is used to format an error when the operator type is unknown. 
diff --git a/ops/opset13/add_test.go b/ops/opset13/add_test.go index d69df55..2b48621 100644 --- a/ops/opset13/add_test.go +++ b/ops/opset13/add_test.go @@ -120,14 +120,14 @@ func TestInputValidationAdd(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(&Add{}, 1), + ops.ErrInvalidInputCount(1, &Add{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(&Add{}, 0, "int"), + ops.ErrInvalidInputType(0, "int", &Add{}), }, } diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index f039710..d465fe7 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinCastInput is the minimimum amount of inputs the add operator expects. + MinCastInput = 1 + + // MaxCastInput is the maximum amount of inputs the add operator accepts. + MaxCastInput = 1 +) + // Cast represents the ONNX cast operator. type Cast struct { to int32 // DataType to cast to, as defined by TensorProto @@ -49,12 +57,12 @@ func (c *Cast) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *Cast) GetMinInputs() int { - return 1 + return MinCastInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (c *Cast) GetMaxInputs() int { - return 1 + return MaxCastInput } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes
diff --git a/ops/opset13/cast_test.go b/ops/opset13/cast_test.go
index 593efe4..5ab663b 100644
--- a/ops/opset13/cast_test.go
+++ b/ops/opset13/cast_test.go
@@ -1,7 +1,6 @@
 package opset13
 
 import (
-	"fmt"
 	"testing"
 
 	"github.com/advancedclimatesystems/gonnx/onnx"
@@ -93,13 +92,13 @@ func TestInputValidationCast(t *testing.T) {
 				ops.TensorWithBackingFixture([]float64{1, 2}, 2),
 				ops.TensorWithBackingFixture([]float64{3, 4}, 2),
 			},
-			fmt.Errorf("cast operator: expected 1 input tensors, got 2"),
+			ops.ErrInvalidInputCount(2, &Cast{}),
 		},
 		{
 			[]tensor.Tensor{
 				ops.TensorWithBackingFixture([]bool{true, false}, 2),
 			},
-			fmt.Errorf("cast operator: input 0 does not allow type bool"),
+			ops.ErrInvalidInputType(0, "bool", &Cast{}),
 		},
 	}
 
@@ -108,6 +107,7 @@ func TestInputValidationCast(t *testing.T) {
 		validated, err := cast.ValidateInputs(test.inputs)
 
 		assert.Equal(t, test.err, err)
+
 		if test.err == nil {
 			assert.Equal(t, test.inputs, validated)
 		}
diff --git a/ops/opset13/concat.go b/ops/opset13/concat.go
index 9795127..1327885 100644
--- a/ops/opset13/concat.go
+++ b/ops/opset13/concat.go
@@ -1,13 +1,16 @@
 package opset13
 
 import (
-	"fmt"
-
 	"github.com/advancedclimatesystems/gonnx/onnx"
 	"github.com/advancedclimatesystems/gonnx/ops"
 	"gorgonia.org/tensor"
 )
 
+const (
+	// MinConcatInput is the minimum amount of inputs the concat operator expects.
+	MinConcatInput = 1
+)
+
 // Concat represents the ONNX concat operator.
 type Concat struct {
 	axis int
@@ -23,10 +26,11 @@ func newConcat() ops.Operator {
 // Init initializes the concat operator.
func (c *Concat) Init(attributes []*onnx.AttributeProto) error {
 	if len(attributes) != 1 {
-		return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes))
+		return ops.ErrInvalidAttributeCount(1, len(attributes), c)
 	}
 
 	c.axis = int(attributes[0].GetI())
+
 	return nil
 }
 
@@ -56,6 +60,7 @@ func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error)
 	// of inputs dynamically, based on our inputs. Every input can have any type.
 	c.maxInputs = len(inputs)
 	c.inputTypeConstraints = make([][]tensor.Dtype, len(inputs))
+
 	for i := 0; i < len(inputs); i++ {
 		c.inputTypeConstraints[i] = ops.AllTypes
 	}
@@ -65,7 +70,7 @@ func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error)
 
 // GetMinInputs returns the minimum number of input tensors this operator expects.
 func (c *Concat) GetMinInputs() int {
-	return 1
+	return MinConcatInput
 }
 
 // GetMaxInputs returns the maximum number of input tensors this operator expects.
diff --git a/ops/opset13/concat_test.go b/ops/opset13/concat_test.go
index 0ad0d46..3f83843 100644
--- a/ops/opset13/concat_test.go
+++ b/ops/opset13/concat_test.go
@@ -1,7 +1,6 @@
 package opset13
 
 import (
-	"fmt"
 	"testing"
 
 	"github.com/advancedclimatesystems/gonnx/onnx"
@@ -22,7 +21,7 @@ func TestConcatInitFail(t *testing.T) {
 	concat := &Concat{}
 	err := concat.Init([]*onnx.AttributeProto{})
 
-	expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, concat, 1, 0)
+	expected := ops.ErrInvalidAttributeCount(1, 0, concat)
 	assert.Equal(t, expected, err)
 }
 
diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go
index 130d4dd..43f2988 100644
--- a/ops/opset13/constant.go
+++ b/ops/opset13/constant.go
@@ -1,13 +1,19 @@
 package opset13
 
 import (
-	"fmt"
-
 	"github.com/advancedclimatesystems/gonnx/onnx"
 	"github.com/advancedclimatesystems/gonnx/ops"
 	"gorgonia.org/tensor"
 )
 
+const (
+	// MinConstInput is the minimum amount of inputs the constant operator expects.
+	MinConstInput = 1
+
+	// MaxConstInput is the maximum amount of inputs the constant operator accepts.
+	MaxConstInput = 1
+)
+
 // Constant represents the ONNX constant operator.
 type Constant struct {
 	value tensor.Tensor
@@ -22,18 +28,20 @@ func newConstant() ops.Operator {
 // `sparse_value`, `value_string`, and `value_strings`.
 func (c *Constant) Init(attributes []*onnx.AttributeProto) error {
 	if len(attributes) != 1 {
-		return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes))
+		return ops.ErrInvalidAttributeCount(1, len(attributes), c)
 	}
+
 	attr := attributes[0]
 
 	switch attr.GetName() {
 	case "sparse_value", "value_string", "value_strings":
-		return fmt.Errorf(ops.UnsupportedAttrErrTemplate, c, attr.GetName())
+		return ops.ErrUnsupportedAttribute(attr.GetName(), c)
 	case "value":
 		t, err := onnx.TensorFromProto(attr.GetT())
 		if err != nil {
 			return err
 		}
+
 		c.value = t
 	case "value_float":
 		c.value = tensor.New(tensor.FromScalar(attr.GetF()))
@@ -46,14 +54,14 @@ func (c *Constant) Init(attributes []*onnx.AttributeProto) error {
 		ints := attr.GetInts()
 		c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints))
 	default:
-		return fmt.Errorf(ops.UnknownAttributeErrTemplate, c, attr.GetName())
+		return ops.ErrInvalidAttribute(attr.GetName(), c)
 	}
 
 	return nil
 }
 
 // Apply applies the constant operator.
-func (c *Constant) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
+func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
 	return []tensor.Tensor{c.value}, nil
 }
 
diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go
index 07d919d..e861cac 100644
--- a/ops/opset13/constant_of_shape.go
+++ b/ops/opset13/constant_of_shape.go
@@ -1,13 +1,19 @@
 package opset13
 
 import (
-	"fmt"
-
 	"github.com/advancedclimatesystems/gonnx/onnx"
 	"github.com/advancedclimatesystems/gonnx/ops"
 	"gorgonia.org/tensor"
 )
 
+const (
+	// MinConstanShapeOfInput is the minimum amount of inputs the constant of shape operator expects.
+ MinConstanShapeOfInput = 1 + + // MaxConstanShapeOfInput is the maximum amount of inputs the add operator accepts. + MaxConstanShapeOfInput = 1 +) + // ConstantOfShape represents the ONNX constant of shape operator. type ConstantOfShape struct { // One element tensor, giving the value and type of the output tensor @@ -21,9 +27,9 @@ func newConstantOfShape() ops.Operator { } // Init initializes the constant of shape operator. -func (op *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { +func (c *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { if len(attributes) > 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, op, "0 or 1", len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } if len(attributes) == 1 { @@ -34,26 +40,22 @@ func (op *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { return err } - op.value = tensor.New(tensor.WithBacking(t.Data())) - if op.value.Len() != 1 { - return fmt.Errorf( - "Value input tensor should be a single element tensor, but was %v", - op.value, - ) + c.value = tensor.New(tensor.WithBacking(t.Data())) + if c.value.Len() != 1 { + return ops.ErrInvalidTensor("expected tensor to have one element", c) } } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, op, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), c) } } else { - // Default - op.value = tensor.New(tensor.FromScalar(float32(0.0))) + c.value = tensor.New(tensor.FromScalar(float32(0.0))) } return nil } // Apply applies the constant of shape operator. 
-func (op *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { shape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[0].Data())) if err != nil { return nil, err @@ -62,42 +64,44 @@ func (op *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error // Empty dimensions in a tensor are not supported for i := range shape { if shape[i] <= 0 { - return nil, fmt.Errorf( - "Non positive dimensions are not allowed (must be > 0). Given: %v", - shape, - ) + return nil, ops.ErrInvalidTensor("no empty dimensions are allowed", c) } } - t := tensor.New(tensor.WithShape(shape...), tensor.Of(op.value.Dtype())) - t, err = t.AddScalar(op.value, true) + + t := tensor.New(tensor.WithShape(shape...), tensor.Of(c.value.Dtype())) + + t, err = t.AddScalar(c.value, true) + if err != nil { + return nil, err + } return []tensor.Tensor{t}, err } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (op *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(op, inputs) +func (c *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (op *ConstantOfShape) GetMinInputs() int { - return 1 +func (c *ConstantOfShape) GetMinInputs() int { + return MinConstanShapeOfInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (op *ConstantOfShape) GetMaxInputs() int { - return 1 +func (c *ConstantOfShape) GetMaxInputs() int { + return MaxConstanShapeOfInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. 
-func (op *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Int64}, } } // String implements the stringer interface, and can be used to format errors or messages. -func (op *ConstantOfShape) String() string { +func (c *ConstantOfShape) String() string { return "constant of shape operator" } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index 658c89c..97cd59c 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -2,7 +2,6 @@ package opset13 import ( "encoding/binary" - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -19,6 +18,7 @@ func TensorProtoFromNumber(n interface{}) *onnx.TensorProto { size := 1 rawData := make([]byte, size) rawData[0] = uint8(x) + return &onnx.TensorProto{ DataType: onnx.TensorProto_DataType_value["INT8"], Dims: []int64{1}, @@ -29,6 +29,7 @@ func TensorProtoFromNumber(n interface{}) *onnx.TensorProto { size := 2 rawData := make([]byte, size) binary.LittleEndian.PutUint16(rawData, uint16(x)) + return &onnx.TensorProto{ DataType: onnx.TensorProto_DataType_value["INT16"], Dims: []int64{1}, @@ -142,7 +143,7 @@ func TestIncorrectInput(t *testing.T) { func TestNegativeShapeNotAllowed(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init([]*onnx.AttributeProto{}) shape := []int64{1, -1} @@ -158,7 +159,7 @@ func TestNegativeShapeNotAllowed(t *testing.T) { func TestEmptyTensorNotAllowed(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init([]*onnx.AttributeProto{}) shape := []int64{0} @@ -174,7 +175,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) { func TestScalarShapeInput(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init([]*onnx.AttributeProto{}) shape := []int64{6} input := tensor.New(tensor.WithBacking(shape)) 
@@ -198,11 +199,11 @@ func TestInputValidationConstantOfShape(t *testing.T) {
 		},
 		{
 			[]tensor.Tensor{},
-			fmt.Errorf("constant of shape operator: expected 1 input tensors, got 0"),
+			ops.ErrInvalidInputCount(1, &ConstantOfShape{}),
 		},
 		{
 			[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)},
-			fmt.Errorf("constant of shape operator: input 0 does not allow type int"),
+			ops.ErrInvalidInputType(0, "int", &ConstantOfShape{}),
 		},
 	}
diff --git a/ops/opset13/constant_test.go b/ops/opset13/constant_test.go
index 739041a..2817bf5 100644
--- a/ops/opset13/constant_test.go
+++ b/ops/opset13/constant_test.go
@@ -46,6 +46,6 @@ func TestConstantInit(t *testing.T) {
 			[]*onnx.AttributeProto{{Name: "sparse_value"}},
 			nil,
-			fmt.Errorf(ops.UnsupportedAttrErrTemplate, &Constant{}, "sparse_value"),
+			ops.ErrUnsupportedAttribute("sparse_value", &Constant{}),
 		},
 		{
 			[]*onnx.AttributeProto{{Name: "unknownAttribute"}},
From db8d50c83c9c4916603af38fbeaff2058ed97d90 Mon Sep 17 00:00:00 2001
From: wisse
Date: Thu, 14 Sep 2023 15:08:08 +0200
Subject: [PATCH 05/15] WIP: More lint fixes

---
 ops/errors.go | 7 ++++--
 ops/opset13/div.go | 12 ++++++++--
 ops/opset13/div_test.go | 5 ++---
 ops/opset13/gather.go | 17 +++++++++++---
 ops/opset13/gemm.go | 16 +++++++++-----
 ops/opset13/gemm_test.go | 6 ++---
 ops/opset13/gru.go | 10 ++++++++-
 ops/opset13/gru_test.go | 29 ++++++++++++------------
 ops/opset13/matmul.go | 44 +++++++++++++++++++++++++++----------
 ops/opset13/matmul_test.go | 5 ++---
 ops/opset13/mul.go | 12 ++++++++--
 ops/opset13/mul_test.go | 4 ++--
 ops/opset13/opset13.go | 11 +++++-----
 ops/opset13/opset13_test.go | 3 +--
 ops/opset13/prelu.go | 2 +-
 15 files changed, 122 insertions(+), 61 deletions(-)

diff --git a/ops/errors.go b/ops/errors.go
index 6259740..9c35a2c 100644
--- a/ops/errors.go
+++ b/ops/errors.go
@@ -172,8 +172,11 @@ func ErrInvalidTensor(reason string, operator Operator) error {
 
 var ErrIncompatibleDimension = errors.New("incompatible dimensions")
 
-// UnknowOpTypeErrTemplate is
used to format an error when the operator type is unknown.
-const UnknowOpTypeErrTemplate = "unknown operator type: %v"
+var UnknownOperatorTypeError = errors.New("unknown operator type")
+
+func ErrUnknownOperatorType(operatorType string) error {
+	return fmt.Errorf("%w: %s", UnknownOperatorTypeError, operatorType)
+}
 
 // MultidirBroadcastErrTemplate is used to format an error when two inputs cannot be
 // broadcasted together with Multidirectional broadcasting.
diff --git a/ops/opset13/div.go b/ops/opset13/div.go
index 87b5cd3..a2df772 100644
--- a/ops/opset13/div.go
+++ b/ops/opset13/div.go
@@ -6,6 +6,14 @@ import (
 	"gorgonia.org/tensor"
 )
 
+const (
+	// MinDivInput is the minimum amount of inputs the div operator expects.
+	MinDivInput = 2
+
+	// MaxDivInput is the maximum amount of inputs the div operator accepts.
+	MaxDivInput = 2
+)
+
 // Div represents the ONNX div operator.
 type Div struct{}
 
@@ -41,12 +49,12 @@ func (d *Div) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 
 // GetMinInputs returns the minimum number of input tensors this operator expects.
 func (d *Div) GetMinInputs() int {
-	return 2
+	return MinDivInput
 }
 
 // GetMaxInputs returns the maximum number of input tensors this operator expects.
 func (d *Div) GetMaxInputs() int {
-	return 2
+	return MaxDivInput
 }
 
 // GetInputTypeConstraints returns a list.
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/div_test.go b/ops/opset13/div_test.go index 420cab4..b0c7ae0 100644 --- a/ops/opset13/div_test.go +++ b/ops/opset13/div_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -108,14 +107,14 @@ func TestInputValidationDiv(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("div operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1.0, &Div{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("div operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Div{}), }, } diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index 19f5199..1449e08 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -8,6 +8,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinGatherInput is the minimimum amount of inputs the add operator expects. + MinGatherInput = 2 + + // MaxGatherInput is the maximum amount of inputs the add operator accepts. + MaxGatherInput = 2 +) + // Gather represents the ONNX gather operator. type Gather struct { axis int // axis to gather on, default is 0 @@ -23,16 +31,19 @@ func (g *Gather) Init(attributes []*onnx.AttributeProto) error { switch length := len(attributes); { case length > 1: return fmt.Errorf(ops.InvalidAttrCountErrTemplate, g, "0 or 1", len(attributes)) + case length == 1: attr := attributes[0] + if attr.GetName() == "axis" { g.axis = int(attr.GetI()) } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } default: g.axis = 0 } + return nil } @@ -85,12 +96,12 @@ func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. 
func (g *Gather) GetMinInputs() int { - return 2 + return MinGatherInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gather) GetMaxInputs() int { - return 2 + return MaxGatherInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index dca1d18..09202cd 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -1,13 +1,19 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + // MinGemmInput is the minimimum amount of inputs the add operator expects. + MinGemmInput = 2 + + // MaxGemmInput is the maximum amount of inputs the add operator accepts. + MaxGemmInput = 3 +) + // Gemm represents the ONNX gemm operator. type Gemm struct { alpha float32 @@ -39,7 +45,7 @@ func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { case "transB": g.transB = ops.Int64ToBool(attr.GetI()) default: - return fmt.Errorf(ops.UnknownAttributeErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } } @@ -107,12 +113,12 @@ func (g *Gemm) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *Gemm) GetMinInputs() int { - return 2 + return MinGemmInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gemm) GetMaxInputs() int { - return 3 + return MaxGemmInput } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index 1bdf09f..fde5001 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -134,7 +134,7 @@ func TestInputValidationGemm(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - fmt.Errorf("gemm operator: expected 2-3 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Gemm{}), }, { []tensor.Tensor{ @@ -144,7 +144,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, - fmt.Errorf("gemm operator: expected 2-3 input tensors, got 4"), + ops.ErrInvalidInputCount(4, &Gemm{}), }, { []tensor.Tensor{ @@ -152,7 +152,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("gemm operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Gemm{}), }, } diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index a642f5b..b4a1405 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -8,6 +8,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinGRUInput is the minimimum amount of inputs the add operator expects. + MinGRUInput = 3 + + // MaxGRUInput is the maximum amount of inputs the add operator accepts. + MaxGRUInput = 6 +) + // GRU represents the ONNX gru operator. It only supports a simple forward gru // operation with default activations. 
type GRU struct { @@ -35,7 +43,7 @@ func (g *GRU) Init(attributes []*onnx.AttributeProto) error { case "linear_before_reset": g.linearBeforeReset = ops.Int64ToBool(attr.GetI()) default: - return fmt.Errorf(ops.UnsupportedAttrErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } } diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index 71b5a82..683056b 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -27,27 +26,27 @@ func TestGruInitUnkownAttr(t *testing.T) { }{ { []*onnx.AttributeProto{{Name: "activation_alpha"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation_alpha"), + ops.ErrInvalidAttribute("activation_alpha", &gru), }, { []*onnx.AttributeProto{{Name: "activation_beta"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation_beta"), + ops.ErrInvalidAttribute("activation_beta", &gru), }, { []*onnx.AttributeProto{{Name: "direction"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "direction"), + ops.ErrInvalidAttribute("direction", &gru), }, { []*onnx.AttributeProto{{Name: "clip"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "clip"), + ops.ErrInvalidAttribute("clip", &gru), }, { []*onnx.AttributeProto{{Name: "activation"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation"), + ops.ErrInvalidAttribute("activation", &gru), }, { []*onnx.AttributeProto{{Name: "unknown"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "unknown"), + ops.ErrInvalidAttribute("unknown", &gru), }, } @@ -138,7 +137,7 @@ func TestInputValidationGRU(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - fmt.Errorf("gru operator: expected 3-6 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &GRU{}), }, { []tensor.Tensor{ @@ -147,7 +146,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 
2}, 2), }, nil, - fmt.Errorf("gru operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &GRU{}), }, { []tensor.Tensor{ @@ -156,7 +155,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &GRU{}), }, { []tensor.Tensor{ @@ -165,7 +164,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &GRU{}), }, { []tensor.Tensor{ @@ -174,7 +173,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 2 does not allow type int"), + ops.ErrInvalidInputType(2, "int", &GRU{}), }, { []tensor.Tensor{ @@ -184,7 +183,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 3 does not allow type int"), + ops.ErrInvalidInputType(3, "int", &GRU{}), }, { []tensor.Tensor{ @@ -195,7 +194,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 4 does not allow type float32"), + ops.ErrInvalidInputType(4, "float32", &GRU{}), }, { []tensor.Tensor{ @@ -207,7 +206,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 5 does not allow type int"), + ops.ErrInvalidInputType(5, "int", &GRU{}), }, } diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index 6f11ed9..5df4fb6 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -8,6 +8,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinMatMulInput is the minimum amount of inputs the matmul operator expects. 
+ MinMatMulInput = 2 + + // MaxMatMulInput is the maximum amount of inputs the matmul operator accepts. + MaxMatMulInput = 2 +) + // MatMul represents the ONNX matmul operator. type MatMul struct{} @@ -29,6 +37,10 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // If both are normal matrices, apply normal matrix multiplication. if len(A.Shape()) == 2 && len(B.Shape()) == 2 { out, err := tensor.MatMul(A, B) + if err != nil { + return nil, err + } + return []tensor.Tensor{out}, err } @@ -36,9 +48,13 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { prependedDimension := false if len(A.Shape()) == 1 { prependedDimension = true - A = A.Clone().(tensor.Tensor) - err := A.Reshape(1, A.Shape()[0]) - if err != nil { + + A, ok := A.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", A) + } + + if err := A.Reshape(1, A.Shape()[0]); err != nil { return nil, err } } @@ -47,9 +63,13 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { appendedDimension := false if len(B.Shape()) == 1 { appendedDimension = true - B = B.Clone().(tensor.Tensor) - err := B.Reshape(B.Shape()[0], 1) - if err != nil { + + B, ok := B.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", B) + } + + if err := B.Reshape(B.Shape()[0], 1); err != nil { return nil, err } } @@ -71,8 +91,8 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { currentShape := out.Shape().Clone() newShape := currentShape[:len(currentShape)-2] newShape = append(newShape, currentShape[len(currentShape)-1]) - err = out.Reshape(newShape...) - if err != nil { + + if err := out.Reshape(newShape...); err != nil { return nil, err } } @@ -80,8 +100,8 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if appendedDimension { currentShape := out.Shape().Clone() newShape := currentShape[:len(currentShape)-1] - err = out.Reshape(newShape...) 
- if err != nil { + + if err = out.Reshape(newShape...); err != nil { return nil, err } } @@ -96,12 +116,12 @@ func (m *MatMul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (m *MatMul) GetMinInputs() int { - return 2 + return MinMatMulInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *MatMul) GetMaxInputs() int { - return 2 + return MaxMatMulInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/matmul_test.go b/ops/opset13/matmul_test.go index b7aa000..18c85a1 100644 --- a/ops/opset13/matmul_test.go +++ b/ops/opset13/matmul_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -179,14 +178,14 @@ func TestInputValidationMatMul(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("matmul operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &MatMul{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("matmul operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &MatMul{}), }, } diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go index 08b3e42..6395572 100644 --- a/ops/opset13/mul.go +++ b/ops/opset13/mul.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinMulInput is the minimum amount of inputs the mul operator expects. + MinMulInput = 2 + + // MaxMulInput is the maximum amount of inputs the mul operator accepts. + MaxMulInput = 2 +) + // Mul represents the ONNX mul operator. type Mul struct{} @@ -41,12 +49,12 @@ func (m *Mul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. 
func (m *Mul) GetMinInputs() int { - return 2 + return MinMulInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *Mul) GetMaxInputs() int { - return 2 + return MaxMulInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/mul_test.go b/ops/opset13/mul_test.go index f0a85de..543b0b9 100644 --- a/ops/opset13/mul_test.go +++ b/ops/opset13/mul_test.go @@ -130,14 +130,14 @@ func TestInputValidationMul(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("mul operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Mul{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("mul operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Mul{}), }, } diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 7bbb74c..4e8bd2c 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/ops" ) @@ -34,18 +32,21 @@ var operators13 = map[string]func() ops.Operator{ } // GetOperator maps strings as found in the ModelProto to Operators from opset 13. -func GetOperator(opType string) (ops.Operator, error) { - if opInit, ok := operators13[opType]; ok { +func GetOperator(operatorType string) (ops.Operator, error) { + if opInit, ok := operators13[operatorType]; ok { return opInit(), nil } - return nil, fmt.Errorf(ops.UnknowOpTypeErrTemplate, opType) + + return nil, ops.ErrUnknownOperatorType(operatorType) } // GetOpNames returns a list with operator names for opset 13. 
func GetOpNames() []string { var opList []string + for opName := range operators13 { opList = append(opList, opName) } + return opList } diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index c826e46..1419e91 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -132,7 +131,7 @@ func TestGetOperator(t *testing.T) { { "NotYetImplemented", nil, - fmt.Errorf(ops.UnknowOpTypeErrTemplate, "NotYetImplemented"), + ops.ErrUnknownOperatorType("NotYetImplemented"), }, } diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go index ea38ce0..b019960 100644 --- a/ops/opset13/prelu.go +++ b/ops/opset13/prelu.go @@ -47,7 +47,7 @@ func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { case tensor.Int64: calcPRelu(y.Data().([]int64), x.Data().([]int64), slope.Data().([]int64)) default: - return nil, fmt.Errorf("%v: unsupported type %v", op, x.Dtype()) + return nil, ops.ErrInvalidInputType(0, x.Dtype().String(), op) } return []tensor.Tensor{y}, nil From c07046a4bc3603b76a8cbd9c1d4e6366104a62dd Mon Sep 17 00:00:00 2001 From: wisse Date: Sun, 22 Oct 2023 11:46:52 +0200 Subject: [PATCH 06/15] More wip on lint errors --- errors.go | 14 +++---- model.go | 4 +- model_test.go | 14 +++---- onnx/graph_proto.go | 71 ++++++++++++++++++++++++++---------- ops/convert.go | 24 ++++++++++-- ops/convert_test.go | 5 +-- ops/errors.go | 2 +- ops/opset13/abs_test.go | 4 +- ops/opset13/constant_test.go | 15 ++++---- ops/unidir_broadcast.go | 11 +++--- ops/unidir_broadcast_test.go | 20 ++-------- ops/utils.go | 13 +++++-- ops/utils_test.go | 13 ++++--- ops/validate_inputs.go | 6 +-- ops/validate_inputs_test.go | 14 ++++--- ops_test.go | 26 +++++++++++-- opset.go | 13 ++++--- 17 files changed, 166 insertions(+), 103 deletions(-) diff --git a/errors.go b/errors.go index a032b6c..1f8f48c 100644 --- a/errors.go +++ b/errors.go @@ -3,24 
+3,22 @@ package gonnx import ( "errors" "fmt" -) -var ( - errInvalidShape = errors.New("input shape does not match") - errSetOutputTensor = errors.New("could not set output tensor") - errModel = errors.New("gonnx model error") + "github.com/advancedclimatesystems/gonnx/onnx" ) +var errModel = errors.New("gonnx model error") + type InvalidShapeError struct { - expected []int + expected onnx.Shape actual []int } func (i InvalidShapeError) Error() string { - return fmt.Sprintf("invalid shape error expected: %v actual %v. mehtod %s", i.expected, i.actual) + return fmt.Sprintf("invalid shape error expected: %v actual %v", i.expected, i.actual) } -func ErrInvalidShape(expected, actual []int) error { +func ErrInvalidShape(expected onnx.Shape, actual []int) error { return InvalidShapeError{ expected: expected, actual: actual, diff --git a/model.go b/model.go index aa5779b..383eee7 100644 --- a/model.go +++ b/model.go @@ -226,7 +226,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { shapeReceived := tensor.Shape() if len(shapeReceived) != len(shapeExpected) { - return ErrInvalidShape("shape does not match for %v: expected %v but got %v", name, shapeExpected, shapeReceived) + return ErrInvalidShape(shapeExpected, shapeReceived) } for i, dim := range shapeExpected { @@ -237,7 +237,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { } if dim.Size != int64(shapeReceived[i]) { - return ErrInvalidShape("shape does not match for %v: expected %v but got %v", name, shapeExpected, shapeReceived) + return ErrInvalidShape(shapeExpected, shapeReceived) } } } diff --git a/model_test.go b/model_test.go index e0d15b0..1f86c41 100644 --- a/model_test.go +++ b/model_test.go @@ -1,8 +1,6 @@ package gonnx import ( - "errors" - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -39,9 +37,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(16)}, ), nil, - errors.New( - "input shape does not match for data_input: expected [0 3] but got (2, 4, 
2)", - ), + ErrModel("input %v only has %d dimensions, but index %d was required", "data_input", 3, 0), }, { "./sample_models/onnx_models/mlp.onnx", @@ -51,7 +47,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(6)}, ), nil, - errors.New("tensor: data_input not found"), + ErrModel("input %v does not exist", "tensor_data"), }, { "./sample_models/onnx_models/gru.onnx", @@ -106,6 +102,7 @@ func TestModel(t *testing.T) { outputs, err := model.Run(test.input) assert.Equal(t, test.err, err) + if test.expected == nil { assert.Nil(t, outputs) } else { @@ -128,6 +125,7 @@ func TestModelIOUtil(t *testing.T) { {IsDynamic: false, Name: "", Size: 3}, }, } + assert.Equal(t, []string{"data_input"}, model.InputNames()) assert.Equal(t, expectedInputShapes, model.InputShapes()) @@ -137,6 +135,7 @@ func TestModelIOUtil(t *testing.T) { {IsDynamic: false, Name: "", Size: 2}, }, } + assert.Equal(t, []string{"preds"}, model.OutputNames()) assert.Equal(t, expectedOutputShapes, model.OutputShapes()) assert.Equal(t, expectedOutputShapes["preds"], model.OutputShape("preds")) @@ -165,7 +164,8 @@ func TestInputDimSizeInvalidInput(t *testing.T) { assert.Nil(t, err) _, err = model.InputDimSize("swagger", 0) - assert.Equal(t, fmt.Errorf("input swagger does not exist"), err) + + assert.Equal(t, ErrModel("input %v does not exist", "swagger"), err) } // tensorsFixture creates Tensors with the given names shapes and backings. This is useful for diff --git a/onnx/graph_proto.go b/onnx/graph_proto.go index 44180c2..cd217e9 100644 --- a/onnx/graph_proto.go +++ b/onnx/graph_proto.go @@ -6,6 +6,7 @@ package onnx import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math" @@ -72,6 +73,8 @@ func (s Shape) String() string { return fmt.Sprintf("%d", dimSizes) } +var ErrInvalidType = errors.New("invalid type") + // Dim is a dimension. 
type Dim struct { IsDynamic bool @@ -154,10 +157,13 @@ func getNamesFromTensorProto(protos []*TensorProto) []string { // TensorFromProto returns a tensor.Tensor from an onnx.TensorProto. func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { - var values interface{} - var err error + var ( + values interface{} + err error + ) typeMap := TensorProto_DataType_value + switch tp.DataType { case typeMap["FLOAT"]: values, err = getFloatData(tp) @@ -194,7 +200,7 @@ func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { case len(tp.Uint64Data) > 0: values, err = getUint64Data(tp) default: - return nil, fmt.Errorf("unsupported datatype for Tensor: %v", tp.DataType) + return nil, ErrInvalidType } } @@ -361,16 +367,20 @@ func ReadUint8ArrayFromBytes(data []byte) ([]uint8, error) { buffer := bytes.NewReader(data) element := make([]byte, uint8Size) - var err error - var values []uint8 + var ( + err error + values []uint8 + ) for { var n int + n, err = buffer.Read(element) if n != uint8Size || err != nil { break } - values = append(values, uint8(element[0])) + + values = append(values, element[0]) } if err != io.EOF { @@ -385,11 +395,14 @@ func ReadInt8ArrayFromBytes(data []byte) ([]int8, error) { buffer := bytes.NewReader(data) element := make([]byte, int8Size) - var err error - var values []int8 + var ( + err error + values []int8 + ) for { var n int + n, err = buffer.Read(element) if n != int8Size || err != nil { break @@ -410,11 +423,14 @@ func ReadUint16ArrayFromBytes(data []byte) ([]uint16, error) { buffer := bytes.NewReader(data) element := make([]byte, uint16Size) - var err error - var values []uint16 + var ( + err error + values []uint16 + ) for { var n int + n, err = buffer.Read(element) if n != uint16Size || err != nil { break @@ -435,11 +451,14 @@ func ReadInt16ArrayFromBytes(data []byte) ([]int16, error) { buffer := bytes.NewReader(data) element := make([]byte, uint16Size) - var err error - var values []int16 + var ( + err error + values []int16 + ) 
for { var n int + n, err = buffer.Read(element) if n != int16Size || err != nil { break @@ -460,11 +479,14 @@ func ReadUint32ArrayFromBytes(data []byte) ([]uint32, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []uint32 + var ( + err error + values []uint32 + ) for { var n int + n, err = buffer.Read(element) if n != uint32Size || err != nil { break @@ -485,11 +507,14 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []int32 + var ( + err error + values []int32 + ) for { var n int + n, err = buffer.Read(element) if n != int32Size || err != nil { break @@ -510,11 +535,14 @@ func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []uint64 + var ( + err error + values []uint64 + ) for { var n int + n, err = buffer.Read(element) if n != uint64Size || err != nil { break @@ -535,11 +563,14 @@ func ReadInt64ArrayFromBytes(data []byte) ([]int64, error) { buffer := bytes.NewReader(data) element := make([]byte, int64Size) - var err error - var values []int64 + var ( + err error + values []int64 + ) for { var n int + n, err = buffer.Read(element) if n != int64Size || err != nil { break diff --git a/ops/convert.go b/ops/convert.go index 6556df9..f32971c 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -1,12 +1,23 @@ package ops import ( + "errors" "fmt" "github.com/advancedclimatesystems/gonnx/onnx" "gorgonia.org/tensor" ) +var ErrConversion = errors.New("unable to convert") + +func ErrConversionInvalidType(dType tensor.Dtype, newType int32) error { + return fmt.Errorf("%w: type %v, to %v is invalid", ErrConversion, dType, newType) +} + +func ErrConversionNotSupported(dType int32) error { + return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) +} + // Number is a type which represents a 
number. type Number interface { float32 | float64 | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 @@ -14,8 +25,11 @@ type Number interface { // ConvertTensorDtype converts an interface of a specific dtype to a new dtype. func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { - var err error - var newBacking any + var ( + err error + newBacking any + ) + backing := IfScalarToSlice(t.Data()) switch t.Dtype() { @@ -40,7 +54,7 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { case tensor.Uint64: newBacking, err = convertBacking(backing.([]uint64), newType) default: - return nil, fmt.Errorf("unable to convert tensor of type %v to type %v", t.Dtype(), newType) + return nil, ErrConversionInvalidType(t.Dtype(), newType) } if err != nil { @@ -72,8 +86,10 @@ func convertBacking[B Number](backing []B, dataType int32) (any, error) { return createNewBacking[B, uint32](backing), nil case onnx.TensorProto_UINT64: return createNewBacking[B, uint64](backing), nil + case onnx.TensorProto_BFLOAT16, onnx.TensorProto_BOOL, onnx.TensorProto_COMPLEX64, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: + return nil, ErrConversionNotSupported(dataType) default: - return nil, fmt.Errorf("converting to onnx datatype %d is not supported yet", dataType) + return nil, ErrConversionNotSupported(dataType) } } diff --git a/ops/convert_test.go b/ops/convert_test.go index 8154e85..399de51 100644 --- a/ops/convert_test.go +++ b/ops/convert_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -91,13 +90,13 @@ func TestConvertTensorDtype(t *testing.T) { tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{true, false})), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1.0, 2.0})), 1, - fmt.Errorf("unable to convert tensor of type bool to type 1"), + ErrConversionInvalidType(tensor.Bool, 1), }, { tensor.New(tensor.WithShape(2), 
tensor.WithBacking([]float32{1.0, 2.1})), nil, 8, - fmt.Errorf("converting to onnx datatype 8 is not supported yet"), + ErrConversionNotSupported(8), }, } diff --git a/ops/errors.go b/ops/errors.go index 9c35a2c..a59b7c5 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -121,7 +121,7 @@ func ErrInvalidInputCount(actual int, operator Operator) error { } } -func ErrInvalidOptionalInputCount(operator Operator, actual int) error { +func ErrInvalidOptionalInputCount(actual int, operator Operator) error { return &InvalidInputCountError{ hasOptionalInputs: true, actualCount: actual, diff --git a/ops/opset13/abs_test.go b/ops/opset13/abs_test.go index 6589c37..e9e0791 100644 --- a/ops/opset13/abs_test.go +++ b/ops/opset13/abs_test.go @@ -124,13 +124,13 @@ func TestInputValidationAbs(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(&Abs{}, 0), + ops.ErrInvalidInputCount(0, &Abs{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(&Abs{}, 0, "int"), + ops.ErrInvalidInputType(0, "int", &Abs{}), }, } diff --git a/ops/opset13/constant_test.go b/ops/opset13/constant_test.go index 2817bf5..c546627 100644 --- a/ops/opset13/constant_test.go +++ b/ops/opset13/constant_test.go @@ -2,7 +2,6 @@ package opset13 import ( "encoding/binary" - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -45,18 +44,17 @@ func TestConstantInit(t *testing.T) { { []*onnx.AttributeProto{{Name: "sparse_value"}}, nil, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &Constant{}, "sparse_value"), - ops.ErrUnsupportedAttribute("sparse_value", &Constant{}) + ops.ErrUnsupportedAttribute("sparse_value", &Constant{}), }, { []*onnx.AttributeProto{{Name: "unknownAttribute"}}, nil, - fmt.Errorf(ops.UnknownAttributeErrTemplate, &Constant{}, "unknownAttribute"), + ops.ErrUnsupportedAttribute("unknownAttribute", &Constant{}), }, { []*onnx.AttributeProto{}, nil, - fmt.Errorf(ops.InvalidAttrCountErrTemplate, &Constant{}, 1, 0), + 
ops.ErrInvalidAttributeCount(1, 0, &Constant{}), }, } @@ -65,6 +63,7 @@ func TestConstantInit(t *testing.T) { err := constant.Init(test.initAttr) assert.Equal(t, test.err, err) + if err != nil { assert.Equal(t, test.expected, constant.value) } @@ -105,7 +104,7 @@ func TestConstant(t *testing.T) { } for _, test := range tests { - test.constant.Init(test.initAttr) + _ = test.constant.Init(test.initAttr) res, err := test.constant.Apply([]tensor.Tensor{}) assert.Nil(t, err) @@ -134,7 +133,7 @@ func TestInputValidationConstant(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("constant operator: expected 0 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Constant{}), }, } @@ -143,6 +142,7 @@ func TestInputValidationConstant(t *testing.T) { validated, err := constant.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } @@ -158,6 +158,7 @@ func ConstantValueAttrProtoFixture() []*onnx.AttributeProto { binary.LittleEndian.PutUint64(bValues[16:24], uint64(values[2])) tp := &onnx.TensorProto{DataType: int32(7), Dims: []int64{3}, RawData: bValues} + return []*onnx.AttributeProto{{Name: "value", T: tp}} } diff --git a/ops/unidir_broadcast.go b/ops/unidir_broadcast.go index 702404c..5bdcbdf 100644 --- a/ops/unidir_broadcast.go +++ b/ops/unidir_broadcast.go @@ -1,8 +1,6 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) @@ -10,12 +8,12 @@ import ( func UnidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { reshapedB, err := reshapeTensorsForUnidirBroadcast(A, B) if err != nil { - return nil, nil, fmt.Errorf(UnidirBroadcastErrTemplate, A.Shape(), B.Shape()) + return nil, nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) } newB, err := repeatTensorsForUnidirBroadcast(A, reshapedB) if err != nil { - return nil, nil, fmt.Errorf(UnidirBroadcastErrTemplate, A.Shape(), B.Shape()) + return nil, nil, ErrUnidirBroadcast(A.Shape(), 
B.Shape()) } return A, newB, nil @@ -34,7 +32,7 @@ func reshapeTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) case nDimsA == nDimsB: return B, nil default: - return nil, fmt.Errorf("tensor B may not have more dimensions than tensor A") + return nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) } } @@ -44,6 +42,7 @@ func reshapeTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) // Example: shapeA=(2, 3, 4) and shapeB=(1, 3, 4) yields shapeNewB=(2, 3, 4). func repeatTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) { var err error + shapeA := A.Shape() shapeB := B.Shape() @@ -54,7 +53,7 @@ func repeatTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) if sizeDimA != sizeDimB { if sizeDimB != 1 { - return nil, fmt.Errorf("incompatible dimensions") + return nil, ErrUnidirBroadcast(shapeA, shapeB) } B, err = tensor.Repeat(B, axis, sizeDimA) diff --git a/ops/unidir_broadcast_test.go b/ops/unidir_broadcast_test.go index 9890e02..98d2ea5 100644 --- a/ops/unidir_broadcast_test.go +++ b/ops/unidir_broadcast_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -42,29 +41,17 @@ func TestUnidirectionalBroadcast(t *testing.T) { { [][]int{{1, 3, 1}, {3, 2}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{1, 3, 1}, - []int{3, 2}, - ), + ErrUnidirBroadcast([]int{1, 3, 1}, []int{3, 2}), }, { [][]int{{5}, {2, 3, 4}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{5}, - []int{2, 3, 4}, - ), + ErrUnidirBroadcast([]int{5}, []int{2, 3, 4}), }, { [][]int{{1, 4, 5}, {1, 1, 3}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{1, 4, 5}, - []int{1, 1, 3}, - ), + ErrUnidirBroadcast([]int{1, 4, 5}, []int{1, 1, 3}), }, } @@ -75,6 +62,7 @@ func TestUnidirectionalBroadcast(t *testing.T) { newA, newB, err := UnidirectionalBroadcast(A, B) assert.Equal(t, test.err, err) + if err == nil { assert.Equal(t, test.expectedShape, newA.Shape()) assert.Equal(t, 
test.expectedShape, newB.Shape()) diff --git a/ops/utils.go b/ops/utils.go index 096bc99..2a81881 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -1,7 +1,7 @@ package ops import ( - "fmt" + "errors" "gorgonia.org/tensor" ) @@ -80,6 +80,11 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { return nil } +var ( + ErrCast = errors.New("cast error") + ErrInvalidShape = errors.New("invalid shape error") +) + // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. func AnyToIntSlice(value interface{}) ([]int, error) { @@ -111,7 +116,7 @@ func AnyToIntSlice(value interface{}) ([]int, error) { return res, nil default: - return nil, fmt.Errorf("could not cast %v to int list", data) + return nil, ErrCast } } @@ -133,7 +138,7 @@ func GetValueAsTensorType(value float64, dtype tensor.Dtype) (interface{}, error case tensor.Float64: return value, nil default: - return nil, fmt.Errorf("unknown type %v, cannot cast constant to this type", dtype) + return nil, ErrCast } } @@ -217,7 +222,7 @@ func NElements(shp ...int) int { // PairwiseAssign essentially does pairwise t1 = t2 in place!. 
func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { if !t1.Shape().Eq(t2.Shape()) { - return fmt.Errorf("Shapes of tensors must be equal, were %v and %v", t1.Shape(), t2.Shape()) + return } it := t1.Iterator() diff --git a/ops/utils_test.go b/ops/utils_test.go index a43d633..9ea0d87 100644 --- a/ops/utils_test.go +++ b/ops/utils_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -106,7 +105,9 @@ func TestOffsetTensorIfNegative(t *testing.T) { } for _, test := range tests { tIn := tensor.New(tensor.WithShape(len(test.in)), tensor.WithBacking(test.in)) - OffsetTensorIfNegative(tIn, test.offset) + err := OffsetTensorIfNegative(tIn, test.offset) + + assert.Nil(t, err) assert.Equal(t, test.expected, tIn.Data()) } } @@ -140,12 +141,13 @@ func TestAnyToIntSlice(t *testing.T) { { "some string", nil, - fmt.Errorf("could not cast some string to int list"), + ErrCast, }, } for _, test := range tests { res, err := AnyToIntSlice(test.in) + assert.Equal(t, test.expected, res) assert.Equal(t, test.err, err) } @@ -204,7 +206,7 @@ func TestGetValueAsTensorType(t *testing.T) { 1.0, tensor.Complex64, nil, - fmt.Errorf("unknown type complex64, cannot cast constant to this type"), + ErrCast, }, } @@ -313,7 +315,7 @@ func TestPairwiseAssign(t *testing.T) { { tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 2, 3, 4})), tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{1, 1})), - fmt.Errorf("Shapes of tensors must be equal, were (2, 2) and (1, 2)"), + ErrInvalidShape, }, } @@ -321,6 +323,7 @@ func TestPairwiseAssign(t *testing.T) { err := PairwiseAssign(test.t1, test.t2) assert.Equal(t, err, test.err) + if err == nil { assert.Equal(t, test.t2.Data(), test.t1.Data()) } diff --git a/ops/validate_inputs.go b/ops/validate_inputs.go index 6995e81..0689c07 100644 --- a/ops/validate_inputs.go +++ b/ops/validate_inputs.go @@ -42,13 +42,13 @@ func checkNInputs(op Operator, inputs []tensor.Tensor) (int, error) { 
if min == max { if nInputs != min { - return 0, ErrInvalidInputCount(op, nInputs) + return 0, ErrInvalidInputCount(nInputs, op) } padLength = min } else { if nInputs < min || nInputs > max { - return 0, ErrInvalidOptionalInputCount(op, nInputs) + return 0, ErrInvalidOptionalInputCount(nInputs, op) } padLength = max @@ -78,7 +78,7 @@ func checkInputTypes(op Operator, inputs []tensor.Tensor) error { typeConstraint := newTypeConstraint(typeConstraints[i]) if _, ok := typeConstraint[input.Dtype()]; !ok { - return ErrInvalidInputType(op, i, input.Dtype().Name()) + return ErrInvalidInputType(i, input.Dtype().Name(), op) } } diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index 4926237..deb1409 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -76,7 +75,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(1, 0), 0, - fmt.Errorf(InvalidInputCountErrTemplate, &MockOp{}, 2, 1), + ErrInvalidInputCount(2, &MockOp{}), }, { &MockOp{ @@ -92,7 +91,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(7, 0), 0, - fmt.Errorf(InvalidOptionalInputCountErrTemplate, &MockOp{}, 3, 5, 7), + ErrInvalidOptionalInputCount(7, &MockOp{}), }, { &MockOp{ @@ -102,7 +101,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(2, 0), 0, - fmt.Errorf("%v: input %d does not allow type %v", &MockOp{}, 1, tensor.Float32), + ErrInvalidInputType(1, "float32", &MockOp{}), }, } @@ -136,12 +135,15 @@ func TestPadInputs(t *testing.T) { func PaddedInputsFixture(nTensors, nNil int) []tensor.Tensor { result := make([]tensor.Tensor, nTensors+nNil) i := 0 + for ; i < nTensors; i++ { result[i] = tensor.New(tensor.WithBacking([]float32{0.0})) } + for ; i < nTensors+nNil; i++ { result[i] = nil } + return result } @@ -151,11 +153,11 @@ type MockOp struct { inputTypeConstraints [][]tensor.Dtype } -func (m *MockOp) Init(attr 
[]*onnx.AttributeProto) error { +func (m *MockOp) Init(_ []*onnx.AttributeProto) error { return nil } -func (m *MockOp) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (m *MockOp) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return nil, nil } diff --git a/ops_test.go b/ops_test.go index 2037d20..20f6c65 100644 --- a/ops_test.go +++ b/ops_test.go @@ -2,7 +2,7 @@ package gonnx import ( "fmt" - "io/ioutil" + "io" "os" "sort" "strings" @@ -106,8 +106,9 @@ type ONNXTestCase struct { } func TestOps(t *testing.T) { - var runnedTests []string + runnedTests := []string{} opNames := opset13.GetOpNames() + for _, opName := range opNames { tests, err := getTestCasesForOp(opName) assert.Nil(t, err) @@ -127,8 +128,10 @@ func TestOps(t *testing.T) { runnedTests = append(runnedTests, test.name) } } + sort.Strings(expectedTests) sort.Strings(runnedTests) + assert.Equal(t, expectedTests, runnedTests) } @@ -146,6 +149,7 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { } var tests []*ONNXTestCase + for _, testFolder := range testFolders { if shouldRunTest(testFolder, opFilter) { testcase, err := getTestCase(fmt.Sprintf("./test_data/%v", testFolder)) @@ -174,6 +178,7 @@ func shouldRunTest(folder, opFilter string) bool { return true } } + return false } @@ -186,6 +191,7 @@ func getTestCase(folder string) (*ONNXTestCase, error) { } basePath := fmt.Sprintf("%v/test_data_set_0", folder) + inputs, err := readTestTensors(basePath, "input", model.mp.Graph.GetInput()) if err != nil { return nil, err @@ -199,11 +205,17 @@ func getTestCase(folder string) (*ONNXTestCase, error) { testcase.model = model testcase.inputs = inputs testcase.outputs = outputs + return testcase, nil } func readTestModel(folder string) (*Model, error) { - bytesModel, err := ioutil.ReadFile(folder + "/model.onnx") + file, err := os.Open(folder + "/model.onnx") + if err != nil { + return nil, err + } + + bytesModel, err := io.ReadAll(file) if err != nil { return nil, err } @@ -230,7 
+242,13 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( for i := 0; i < len(inputs); i++ { filePath := fmt.Sprintf("%v/%v_%d.pb", basePath, baseFile, i) - bytesInput, err := ioutil.ReadFile(filePath) + + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + + bytesInput, err := io.ReadAll(file) if err != nil { return nil, err } diff --git a/opset.go b/opset.go index 0471864..64f1151 100644 --- a/opset.go +++ b/opset.go @@ -1,12 +1,14 @@ package gonnx import ( - "fmt" + "errors" "github.com/advancedclimatesystems/gonnx/ops" "github.com/advancedclimatesystems/gonnx/ops/opset13" ) +var ErrInvalidOperator = errors.New("invalid operator getter") + // OpGetter is a function that gets an operator based on a string. type OpGetter func(string) (ops.Operator, error) @@ -16,14 +18,15 @@ var operatorGetters = map[int64]OpGetter{ // ResolveOperatorGetter resolves the getter for operators based on the opset version. func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { - if GetOperator, ok := operatorGetters[opsetID]; ok { - return GetOperator, nil + if getOperator, ok := operatorGetters[opsetID]; ok { + return getOperator, nil } - opsets := make([]int64, len(operatorGetters)) + opsets := make([]int64, 0, len(operatorGetters)) for version := range operatorGetters { + // TODO what does this do again? 
opsets = append(opsets, version) } - return nil, fmt.Errorf("expected opset to be in %d, got operator set %d", opsets, opsetID) + return nil, ErrInvalidOperator } From f4ae78c3764013f703df64953a1474fd1a42a9ec Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 23 Oct 2023 08:26:45 +0200 Subject: [PATCH 07/15] WIP on lint --- ops/opset13/matmul.go | 11 +++++++++-- ops/opset13/sigmoid.go | 3 ++- ops/opset13/slice.go | 18 +++++++++++++++--- ops/opset13/sub.go | 14 +++++++++++--- ops/opset13/tanh.go | 3 ++- opset.go | 6 ------ 6 files changed, 39 insertions(+), 16 deletions(-) diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index 5df4fb6..81da79a 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -25,7 +25,7 @@ func newMatMul() ops.Operator { } // Init initializes the matmul operator. -func (m *MatMul) Init(attributes []*onnx.AttributeProto) error { +func (m *MatMul) Init(_ []*onnx.AttributeProto) error { return nil } @@ -152,7 +152,10 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten // want to broadcast those. All leading dimensions we do want to broadcast. shapeA := A.Shape() shapeB := B.Shape() - for axis := len(shapeA) - 3; axis >= 0; axis-- { + + nMatrixDims := 3 + + for axis := len(shapeA) - nMatrixDims; axis >= 0; axis-- { sizeDimA := shapeA[axis] sizeDimB := shapeB[axis] @@ -183,6 +186,7 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten func (m *MatMul) batchedMatMul(A, B tensor.Tensor) (tensor.Tensor, error) { shapeA := A.Shape() shapeB := B.Shape() + outerShape := append([]int{}, shapeA[:len(shapeA)-2]...) // This will be the shape of the output tensor. @@ -197,7 +201,9 @@ func (m *MatMul) batchedMatMul(A, B tensor.Tensor) (tensor.Tensor, error) { } var err error + var matrixA, matrixB, matrixOut tensor.Tensor + for { matrixA, err = A.Slice(slices...) 
if err != nil { @@ -244,6 +250,7 @@ func incrementSlices(slices []tensor.Slice, shape []int) bool { slices[i] = ops.NewSlicer(0) // Else we start again for this dimension. } else { slices[i] = ops.NewSlicer(dimSliceStart + 1) + return true } } diff --git a/ops/opset13/sigmoid.go b/ops/opset13/sigmoid.go index e6c4749..6171e99 100644 --- a/ops/opset13/sigmoid.go +++ b/ops/opset13/sigmoid.go @@ -15,13 +15,14 @@ func newSigmoid() ops.Operator { } // Init initializes the sigmoid operator. -func (s *Sigmoid) Init(attributes []*onnx.AttributeProto) error { +func (s *Sigmoid) Init(_ []*onnx.AttributeProto) error { return nil } // Apply the sigmoid operator to the input node. func (s *Sigmoid) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, err := ops.Sigmoid(inputs[0]) + return []tensor.Tensor{out}, err } diff --git a/ops/opset13/slice.go b/ops/opset13/slice.go index 7279c6a..c31715f 100644 --- a/ops/opset13/slice.go +++ b/ops/opset13/slice.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinSliceInput is the minimimum amount of inputs the slice operator expects. + MinSliceInput = 3 + + // MaxSliceInput is the maximum amount of inputs the slice operator accepts. + MaxSliceInput = 5 +) + // Slice represents the ONNX slice operator. type Slice struct{} @@ -15,13 +23,14 @@ func newSlice() ops.Operator { } // Init initializes the slice operator. -func (s *Slice) Init(attributes []*onnx.AttributeProto) error { +func (s *Slice) Init(_ []*onnx.AttributeProto) error { return nil } // Apply applies the slice operator. func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { data := inputs[0] + starts, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data())) if err != nil { return nil, err @@ -65,12 +74,12 @@ func (s *Slice) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. 
func (s *Slice) GetMinInputs() int { - return 3 + return MinSliceInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Slice) GetMaxInputs() int { - return 5 + return MaxSliceInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -102,6 +111,7 @@ func (s *Slice) constructSlices(starts, ends, steps, axes []int, nTotalSlices in if ax < 0 { ax = nTotalSlices + ax } + slices[ax] = ops.NewSlicer(starts[i], ends[i], steps[i]) } @@ -114,6 +124,7 @@ func (s *Slice) getDefaultAxes(nSlices int) []int { for i := 0; i < nSlices; i++ { axes[i] = i } + return axes } @@ -123,5 +134,6 @@ func (s *Slice) getDefaultSteps(nSlices int) []int { for i := 0; i < nSlices; i++ { steps[i] = 1 } + return steps } diff --git a/ops/opset13/sub.go b/ops/opset13/sub.go index 0f6ca20..8e49337 100644 --- a/ops/opset13/sub.go +++ b/ops/opset13/sub.go @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +const ( + // MinSubInput is the minimimum amount of inputs the sub operator expects. + MinSubInput = 2 + + // MaxSubInput is the maximum amount of inputs the sub operator accepts. + MaxSubInput = 2 +) + // Sub represents the ONNX sub operator. type Sub struct{} @@ -15,7 +23,7 @@ func newSub() ops.Operator { } // Init initializes the sub operator. -func (s *Sub) Init(attributes []*onnx.AttributeProto) error { +func (s *Sub) Init(_ []*onnx.AttributeProto) error { return nil } @@ -41,12 +49,12 @@ func (s *Sub) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Sub) GetMinInputs() int { - return 2 + return MinSubInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Sub) GetMaxInputs() int { - return 2 + return MaxSubInput } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/tanh.go b/ops/opset13/tanh.go index fab9a02..4939941 100644 --- a/ops/opset13/tanh.go +++ b/ops/opset13/tanh.go @@ -15,13 +15,14 @@ func newTanh() ops.Operator { } // Init initializes the sigmoid operator. -func (t *Tanh) Init(attributes []*onnx.AttributeProto) error { +func (t *Tanh) Init(_ []*onnx.AttributeProto) error { return nil } // Apply the sigmoid operator to the input node. func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, err := ops.Tanh(inputs[0]) + return []tensor.Tensor{out}, err } diff --git a/opset.go b/opset.go index 64f1151..2ee297f 100644 --- a/opset.go +++ b/opset.go @@ -22,11 +22,5 @@ func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { return getOperator, nil } - opsets := make([]int64, 0, len(operatorGetters)) - for version := range operatorGetters { - // TODO what does this do again? - opsets = append(opsets, version) - } - return nil, ErrInvalidOperator } From b31978e503af7cc73c19aa1e9d33a38a0bd2843d Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 23 Oct 2023 08:47:08 +0200 Subject: [PATCH 08/15] More WIP --- ops/convert.go | 2 +- ops/multidir_broadcast_test.go | 1 + ops/opset13/gru.go | 36 ++++++++++++++++++++++++++-------- ops/opset13/unsqueeze.go | 19 ++++++++++++++---- ops/slicer.go | 4 +++- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/ops/convert.go b/ops/convert.go index f32971c..591d445 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -86,7 +86,7 @@ func convertBacking[B Number](backing []B, dataType int32) (any, error) { return createNewBacking[B, uint32](backing), nil case onnx.TensorProto_UINT64: return createNewBacking[B, uint64](backing), nil - case onnx.TensorProto_BFLOAT16, onnx.TensorProto_BOOL, onnx.TensorProto_COMPLEX64, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: + case onnx.TensorProto_BFLOAT16, onnx.TensorProto_BOOL, onnx.TensorProto_COMPLEX64, 
onnx.TensorProto_COMPLEX128, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: return nil, ErrConversionNotSupported(dataType) default: return nil, ErrConversionNotSupported(dataType) diff --git a/ops/multidir_broadcast_test.go b/ops/multidir_broadcast_test.go index c0f454e..3433a37 100644 --- a/ops/multidir_broadcast_test.go +++ b/ops/multidir_broadcast_test.go @@ -73,6 +73,7 @@ func TestMultidirectionalBroadcast(t *testing.T) { newA, newB, err := MultidirectionalBroadcast(A, B) assert.Equal(t, test.err, err) + if err == nil { assert.Equal(t, test.expectedShape, newA.Shape()) assert.Equal(t, test.expectedShape, newB.Shape()) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index b4a1405..96915c2 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -56,6 +56,7 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { W := inputs[1] R := inputs[2] B := inputs[3] + if inputs[4] != nil { return nil, fmt.Errorf("%v: sequence lens not yet supported as input", g) } @@ -87,18 +88,24 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if initialH == nil { prevH = g.initialH(batchSize) } else { - prevH = initialH.Clone().(tensor.Tensor) + var ok bool + prevH, ok = initialH.Clone().(tensor.Tensor) + if !ok { + return nil, fmt.Errorf("could not clone the initial hidden state tensor") + } } // Extract the shape of the hidden dimensions without the bidirectional dimension, as // we do not support bidirectional GRU yet. shapeWithoutBidir := prevH.Shape().Clone()[1:] + err = prevH.Reshape(shapeWithoutBidir...) 
if err != nil { return nil, err } outputs := []tensor.Tensor{} + for i := 0; i < seqLength; i++ { Xt, err := g.extractXt(X, i) if err != nil { @@ -144,7 +151,11 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - Yh := prevH.Clone().(tensor.Tensor) + Yh, ok := prevH.Clone().(tensor.Tensor) + if !ok { + return nil, fmt.Errorf("could not clone the hidden tensor") + } + // Reshape the output so it adds the num_directions as specified by onnx. err = Yh.Reshape([]int{1, batchSize, g.hiddenSize}...) if err != nil { @@ -161,12 +172,12 @@ func (g *GRU) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *GRU) GetMinInputs() int { - return 3 + return MinGRUInput } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *GRU) GetMaxInputs() int { - return 6 + return MaxGRUInput } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -196,6 +207,7 @@ func (g *GRU) gateCalculation( Xt, H, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) if err != nil { return nil, err @@ -222,6 +234,7 @@ func (g *GRU) htCalculation( if err != nil { return nil, err } + return g.gateCalculation(Xt, temp1, W, R, Wb, Rb, activation) } @@ -313,6 +326,7 @@ func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { for i := 0; i < 3; i++ { slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) + w, err := W.Slice(dirSlice, slice, nil) if err != nil { return nil, err @@ -332,11 +346,14 @@ func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { // B has a shape of (num_directions, 6 * hidden_size) and every individual bias tensor should have // shape (hidden_size). 
We extract the biases by slicing over the '6 * hidden_size' dimension. func (g *GRU) extractBiases(B tensor.Tensor) ([]tensor.Tensor, error) { + const nWeightMatrices = 6 + dirSlice := ops.NewSlicer(0) - biases := make([]tensor.Tensor, 7) + biases := make([]tensor.Tensor, nWeightMatrices) - for i := 0; i < 6; i++ { + for i := 0; i < nWeightMatrices; i++ { slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) + w, err := B.Slice(dirSlice, slice) if err != nil { return nil, err @@ -352,9 +369,11 @@ func (g *GRU) extractBiases(B tensor.Tensor) ([]tensor.Tensor, error) { // of the gru operator this tensor can be used as the biases tensor. By default biases // are all 0. func (g *GRU) initialB() tensor.Tensor { + const nWeightMatrices = 6 + return tensor.New( - tensor.WithShape(1, 6*g.hiddenSize), - tensor.WithBacking(ops.Zeros(6*g.hiddenSize)), + tensor.WithShape(1, nWeightMatrices*g.hiddenSize), + tensor.WithBacking(ops.Zeros(nWeightMatrices*g.hiddenSize)), ) } @@ -363,6 +382,7 @@ func (g *GRU) initialB() tensor.Tensor { // (num_directions, batch_size, hidden_size). func (g *GRU) initialH(batchSize int) tensor.Tensor { hiddenFloats := ops.Zeros(batchSize * g.hiddenSize) + return tensor.New( tensor.WithShape(1, batchSize, g.hiddenSize), tensor.WithBacking(hiddenFloats), diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index 51d3cd1..cd7d0ec 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -9,6 +9,9 @@ import ( "gorgonia.org/tensor" ) +// NUnsqueezeInputs is the exact number of inputs the unsqueeze operator expects. +const NUnsqueezeInputs = 2 + // Unsqueeze represents the ONNX unsqueeze operator. type Unsqueeze struct{} @@ -18,13 +21,14 @@ func newUnsqueeze() ops.Operator { } // Init initializes the unsqueeze operator. -func (u *Unsqueeze) Init(attributes []*onnx.AttributeProto) error { +func (u *Unsqueeze) Init(_ []*onnx.AttributeProto) error { return nil } // Apply applies the unsqueeze operator. 
func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { dataShape := inputs[0].Shape() + axes, err := ops.AnyToIntSlice(inputs[1].Data()) if err != nil { return nil, err @@ -48,8 +52,13 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { newShape := insertOnes(dataShape, axes) - out := inputs[0].Clone().(tensor.Tensor) + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, fmt.Errorf("could not copy the input tensor") + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -60,12 +69,12 @@ func (u *Unsqueeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. func (u *Unsqueeze) GetMinInputs() int { - return 2 + return NUnsqueezeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (u *Unsqueeze) GetMaxInputs() int { - return 2 + return NUnsqueezeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -91,6 +100,7 @@ func insertOnes(original, indices []int) []int { originalIdx := 0 indicesIdx := 0 + for i := 0; i < N; i++ { if indicesIdx < len(indices) && indices[indicesIdx] == i { newShape[i] = 1 @@ -100,5 +110,6 @@ func insertOnes(original, indices []int) []int { originalIdx++ } } + return newShape } diff --git a/ops/slicer.go b/ops/slicer.go index 2dc1a00..8d82430 100644 --- a/ops/slicer.go +++ b/ops/slicer.go @@ -13,6 +13,8 @@ type Slicer struct { // will be set to 1. If options are given, it is assumed that the first element will be the value // for the end index and the second element the value for the step of the Slicer. 
func NewSlicer(start int, options ...int) tensor.Slice { + const maxOptionLength = 2 + end := start + 1 step := 1 @@ -20,7 +22,7 @@ func NewSlicer(start int, options ...int) tensor.Slice { end = options[0] } - if len(options) >= 2 { + if len(options) >= maxOptionLength { step = options[1] } From ca1503e416f15ab52be22fd3624ada1744db6afb Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 29 Oct 2023 19:40:46 +0100 Subject: [PATCH 09/15] Fix all lints except errors --- ops/convert_test.go | 1 + ops/opset13/constant_of_shape_test.go | 1 + ops/opset13/div.go | 2 +- ops/opset13/div_test.go | 1 + ops/opset13/gather.go | 27 ++++++++++++++++++++++++--- ops/opset13/gather_test.go | 1 + ops/opset13/gemm.go | 1 + ops/opset13/gemm_test.go | 2 ++ ops/opset13/gru.go | 4 +++- ops/opset13/gru_test.go | 1 + ops/opset13/matmul_test.go | 1 + ops/opset13/mul_test.go | 1 + ops/opset13/prelu.go | 15 +++++++++++---- ops/opset13/prelu_test.go | 2 ++ ops/opset13/relu.go | 2 +- ops/opset13/relu_test.go | 1 + ops/opset13/reshape.go | 24 +++++++++++++++++++----- ops/opset13/reshape_test.go | 1 + ops/opset13/scaler.go | 11 ++++++++--- ops/opset13/scaler_test.go | 1 + ops/opset13/shape.go | 10 +++++++--- ops/opset13/shape_test.go | 1 + ops/opset13/sigmoid_test.go | 1 + ops/opset13/slice_test.go | 2 ++ ops/opset13/squeeze_test.go | 1 + ops/opset13/sub_test.go | 1 + ops/opset13/tanh_test.go | 1 + ops/opset13/transpose.go | 7 +++++-- ops/opset13/transpose_test.go | 1 + ops/opset13/unsqueeze_test.go | 19 ++++++++++++------- 30 files changed, 114 insertions(+), 30 deletions(-) diff --git a/ops/convert_test.go b/ops/convert_test.go index 399de51..400f67e 100644 --- a/ops/convert_test.go +++ b/ops/convert_test.go @@ -104,6 +104,7 @@ func TestConvertTensorDtype(t *testing.T) { out, err := ConvertTensorDtype(test.tensorIn, test.newType) assert.Equal(t, test.err, err) + if test.err != nil { continue } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index 
97cd59c..d1cc3e9 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -212,6 +212,7 @@ func TestInputValidationConstantOfShape(t *testing.T) { validated, err := constantOfShape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/div.go b/ops/opset13/div.go index a2df772..530aa89 100644 --- a/ops/opset13/div.go +++ b/ops/opset13/div.go @@ -23,7 +23,7 @@ func newDiv() ops.Operator { } // Init initializes the div operator. -func (d *Div) Init(attributes []*onnx.AttributeProto) error { +func (d *Div) Init(_ []*onnx.AttributeProto) error { return nil } diff --git a/ops/opset13/div_test.go b/ops/opset13/div_test.go index b0c7ae0..6af8f60 100644 --- a/ops/opset13/div_test.go +++ b/ops/opset13/div_test.go @@ -123,6 +123,7 @@ func TestInputValidationDiv(t *testing.T) { validated, err := div.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index 1449e08..c0406e2 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -54,6 +54,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if err != nil { return nil, err } + indices := tensor.New(tensor.WithBacking(indicesData), tensor.WithShape(inputs[1].Shape()...)) data := inputs[0] @@ -61,6 +62,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // Make sure axis is in the correct range (according to the size of the data tensor) rank := len(data.Shape()) dataAxis := g.axis + if dataAxis < -rank || dataAxis > rank-1 { return nil, fmt.Errorf(ops.AxisOutOfRangeErrTemplate, rank, rank, dataAxis) } @@ -75,7 +77,11 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if !ops.AllInRange(indicesData, -axisDimSize, axisDimSize-1) { return nil, 
fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, axisDimSize, axisDimSize) } - ops.OffsetTensorIfNegative(indices, axisDimSize) + + err = ops.OffsetTensorIfNegative(indices, axisDimSize) + if err != nil { + return nil, err + } // Make the shape of the output tensor os := insertWithReplace(indices.Shape(), data.Shape(), dataAxis) @@ -86,6 +92,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if err != nil { return nil, err } + return []tensor.Tensor{output}, nil } @@ -163,13 +170,20 @@ func (g *Gather) String() string { // slicing to extract the blocks that we need to assign, and then pairwise assign them. func gather(out, data, indices tensor.Tensor, axis int) error { it := indices.Iterator() - for it.Reset(); !it.Done(); it.Next() { + it.Reset() + + for !it.Done() { coords := it.Coord() + at, err := indices.At(coords...) if err != nil { return err } - k := at.(int) + + k, ok := at.(int) + if !ok { + return fmt.Errorf("could not cast to int") + } // Slice that selects `k` on the given axis. // Equivalent to: data[:, ... , :, k, :, ..., :], where `k` is on the index `axis` @@ -186,12 +200,18 @@ func gather(out, data, indices tensor.Tensor, axis int) error { for i, s := range coords { oslices[i+axis] = ops.NewSlicer(s) } + outputSlice, _ := out.Slice(oslices...) err = ops.PairwiseAssign(outputSlice, dataSlice) if err != nil { return err } + + _, err = it.Next() + if err != nil { + return err + } } return nil @@ -207,6 +227,7 @@ func gather(out, data, indices tensor.Tensor, axis int) error { func insertWithReplace(a, x []int, axis int) []int { y := append([]int{}, x[:axis]...) y = append(y, a...) + if axis+1 < len(x) { y = append(y, x[axis+1:]...) 
} diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index 685e537..e6c4da1 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -293,6 +293,7 @@ func TestInputValidationGather(t *testing.T) { validated, err := gather.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index 09202cd..c0cbec9 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -55,6 +55,7 @@ func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { // Apply applies the gemm operator on the given graph. func (g *Gemm) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var err error + a := inputs[0] b := inputs[1] c := inputs[2] diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index fde5001..3d66ae6 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -96,6 +96,7 @@ func TestGemm(t *testing.T) { } else { inputs = append(inputs, nil) } + res, err := test.gemm.Apply(inputs) assert.Nil(t, err) @@ -161,6 +162,7 @@ func TestInputValidationGemm(t *testing.T) { validated, err := gemm.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 96915c2..84c95ec 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -321,8 +321,10 @@ func (g *GRU) getBiases(B tensor.Tensor) (Wbz, Wbr, Wbh, Rbz, Rbr, Rbh tensor.Te // W will have a shape of (num_directions, 3 * hidden_size, ...) and we extract the // by slicing over the '3 * hidden_size' dimension. 
func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { + const nWeightMatrices = 3 + dirSlice := ops.NewSlicer(0) - weights := make([]tensor.Tensor, 3) + weights := make([]tensor.Tensor, nWeightMatrices) for i := 0; i < 3; i++ { slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index 683056b..cc278d3 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -215,6 +215,7 @@ func TestInputValidationGRU(t *testing.T) { validated, err := gru.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/matmul_test.go b/ops/opset13/matmul_test.go index 18c85a1..71a7ac0 100644 --- a/ops/opset13/matmul_test.go +++ b/ops/opset13/matmul_test.go @@ -194,6 +194,7 @@ func TestInputValidationMatMul(t *testing.T) { validated, err := matmul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/mul_test.go b/ops/opset13/mul_test.go index 543b0b9..b4be910 100644 --- a/ops/opset13/mul_test.go +++ b/ops/opset13/mul_test.go @@ -146,6 +146,7 @@ func TestInputValidationMul(t *testing.T) { validated, err := mul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go index b019960..754cb18 100644 --- a/ops/opset13/prelu.go +++ b/ops/opset13/prelu.go @@ -8,6 +8,11 @@ import ( "gorgonia.org/tensor" ) +const ( + PReluMinInputs = 2 + PReluMaxInputs = 2 +) + // PRelu represents the ONNX prelu operator. type PRelu struct{} @@ -17,15 +22,16 @@ func newPRelu() ops.Operator { } // Init initializes the prelu operator. 
-func (op *PRelu) Init(attributes []*onnx.AttributeProto) error { +func (op *PRelu) Init(_ []*onnx.AttributeProto) error { return nil } // Apply applies the prelu operator. func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + x, slope := inputs[0], inputs[1] - var err error x, slope, err = ops.UnidirectionalBroadcast(x, slope) if err != nil { return nil, err @@ -70,12 +76,12 @@ func (op *PRelu) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (op *PRelu) GetMinInputs() int { - return 2 + return PReluMinInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (op *PRelu) GetMaxInputs() int { - return 2 + return PReluMaxInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -97,6 +103,7 @@ func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result []T if v < 0 { v = slope[i] * v } + result[i] = v } } diff --git a/ops/opset13/prelu_test.go b/ops/opset13/prelu_test.go index 2a1c8cf..38d2de9 100644 --- a/ops/opset13/prelu_test.go +++ b/ops/opset13/prelu_test.go @@ -83,6 +83,7 @@ func TestInputValidationPRelu(t *testing.T) { validated, err := prelu.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } @@ -102,6 +103,7 @@ func BenchmarkPRelu_Apply(b *testing.B) { if err != nil { b.Fatal(err) } + _ = y } } diff --git a/ops/opset13/relu.go b/ops/opset13/relu.go index 31e5330..8a169c2 100644 --- a/ops/opset13/relu.go +++ b/ops/opset13/relu.go @@ -15,7 +15,7 @@ func newRelu() ops.Operator { } // Init initializes the relu operator. 
-func (r *Relu) Init(attributes []*onnx.AttributeProto) error { +func (r *Relu) Init(_ []*onnx.AttributeProto) error { return nil } diff --git a/ops/opset13/relu_test.go b/ops/opset13/relu_test.go index 524779d..ba417a6 100644 --- a/ops/opset13/relu_test.go +++ b/ops/opset13/relu_test.go @@ -81,6 +81,7 @@ func TestInputValidationRelu(t *testing.T) { validated, err := relu.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/reshape.go b/ops/opset13/reshape.go index 45687a1..a255cc5 100644 --- a/ops/opset13/reshape.go +++ b/ops/opset13/reshape.go @@ -8,6 +8,11 @@ import ( "gorgonia.org/tensor" ) +const ( + ReshapeMinInputs = 2 + ReshapeMaxInputs = 2 +) + // Reshape represents the ONNX reshape operator. type Reshape struct{} @@ -17,13 +22,14 @@ func newReshape() ops.Operator { } // Init initializes the reshape operator. -func (r *Reshape) Init(attributes []*onnx.AttributeProto) error { +func (r *Reshape) Init(_ []*onnx.AttributeProto) error { return nil } // Apply applies the reshape operator. func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { t := inputs[0] + newShape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data().([]int64))) if err != nil { return nil, err @@ -34,8 +40,13 @@ func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - out := t.Clone().(tensor.Tensor) + out, ok := t.Clone().(tensor.Tensor) + if !ok { + return nil, fmt.Errorf("could not cast to tensor") + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -46,12 +57,12 @@ func (r *Reshape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error // GetMinInputs returns the minimum number of input tensors this operator expects. func (r *Reshape) GetMinInputs() int { - return 2 + return ReshapeMinInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (r *Reshape) GetMaxInputs() int { - return 2 + return ReshapeMaxInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -71,6 +82,7 @@ func processShape(newShape, currentShape []int) error { if i >= len(currentShape) { return fmt.Errorf("could not infer dim size") } + newShape[i] = currentShape[i] } } @@ -82,19 +94,21 @@ func processShape(newShape, currentShape []int) error { // When encountering a -1 dim size, calculate which size this should be. if newShape[i] == -1 { remainingSize := totalSize + for j := 0; j < len(newShape); j++ { if j == i { continue } if newShape[j] == -1 { - return fmt.Errorf("At most one -1 dim size is allowed") + return fmt.Errorf("at most one -1 dim size is allowed") } remainingSize /= newShape[j] } newShape[i] = remainingSize + break } } diff --git a/ops/opset13/reshape_test.go b/ops/opset13/reshape_test.go index 1df7205..de642b7 100644 --- a/ops/opset13/reshape_test.go +++ b/ops/opset13/reshape_test.go @@ -100,6 +100,7 @@ func TestInputValidationReshape(t *testing.T) { validated, err := reshape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/scaler.go b/ops/opset13/scaler.go index c72d702..f642979 100644 --- a/ops/opset13/scaler.go +++ b/ops/opset13/scaler.go @@ -8,6 +8,11 @@ import ( "gorgonia.org/tensor" ) +const ( + ScalerExpectedAttributes = 2 + ScalerInputs = 1 +) + // Scaler represents the ONNX-ml scaler operator. type Scaler struct { offset tensor.Tensor @@ -21,7 +26,7 @@ func newScaler() ops.Operator { // Init initializes the scaler operator. 
func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { - if len(attributes) != 2 { + if len(attributes) != ScalerExpectedAttributes { return fmt.Errorf(ops.InvalidAttrCountErrTemplate, s, 2, len(attributes)) } @@ -73,12 +78,12 @@ func (s *Scaler) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Scaler) GetMinInputs() int { - return 1 + return ScalerInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Scaler) GetMaxInputs() int { - return 1 + return ScalerInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/scaler_test.go b/ops/opset13/scaler_test.go index e0b1345..1e86462 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/opset13/scaler_test.go @@ -122,6 +122,7 @@ func TestInputValidationScaler(t *testing.T) { validated, err := scaler.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/shape.go b/ops/opset13/shape.go index e507640..16b0978 100644 --- a/ops/opset13/shape.go +++ b/ops/opset13/shape.go @@ -6,6 +6,8 @@ import ( "gorgonia.org/tensor" ) +const NShapeInputs = 1 + // Shape represents the ONNX shape operator. type Shape struct{} @@ -15,7 +17,7 @@ func newShape() ops.Operator { } // Init initializes the shape operator. 
-func (s *Shape) Init(attributes []*onnx.AttributeProto) error { +func (s *Shape) Init(_ []*onnx.AttributeProto) error { return nil } @@ -24,11 +26,13 @@ func (s *Shape) Init(attributes []*onnx.AttributeProto) error { func (s *Shape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { nodeShape := inputs[0].Shape() shape := make([]int64, len(nodeShape)) + for i, dimSize := range nodeShape { shape[i] = int64(dimSize) } out := tensor.New(tensor.WithShape(len(nodeShape)), tensor.WithBacking(shape)) + return []tensor.Tensor{out}, nil } @@ -39,12 +43,12 @@ func (s *Shape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Shape) GetMinInputs() int { - return 1 + return NShapeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Shape) GetMaxInputs() int { - return 1 + return NShapeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/shape_test.go b/ops/opset13/shape_test.go index 4da1aa4..7cc964b 100644 --- a/ops/opset13/shape_test.go +++ b/ops/opset13/shape_test.go @@ -73,6 +73,7 @@ func TestInputValidationShape(t *testing.T) { validated, err := shape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/sigmoid_test.go b/ops/opset13/sigmoid_test.go index 7dddd20..3e64bd3 100644 --- a/ops/opset13/sigmoid_test.go +++ b/ops/opset13/sigmoid_test.go @@ -92,6 +92,7 @@ func TestInputValidationSigmoid(t *testing.T) { validated, err := sigmoid.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/slice_test.go b/ops/opset13/slice_test.go index 1741ed8..7165c96 100644 --- a/ops/opset13/slice_test.go +++ b/ops/opset13/slice_test.go @@ -140,6 +140,7 @@ func 
TestConstructSlices(t *testing.T) { ) assert.Equal(t, test.nSlices, len(slices)) + for i := 0; i < test.nSlices; i++ { if test.expectedSlices[i] == nil { assert.Nil(t, slices[i]) @@ -217,6 +218,7 @@ func TestInputValidationSlice(t *testing.T) { validated, err := slice.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/squeeze_test.go b/ops/opset13/squeeze_test.go index 420df27..38279af 100644 --- a/ops/opset13/squeeze_test.go +++ b/ops/opset13/squeeze_test.go @@ -175,6 +175,7 @@ func TestInputValidationSqueeze(t *testing.T) { validated, err := squeeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/sub_test.go b/ops/opset13/sub_test.go index 725113c..d9830af 100644 --- a/ops/opset13/sub_test.go +++ b/ops/opset13/sub_test.go @@ -120,6 +120,7 @@ func TestInputValidationSub(t *testing.T) { validated, err := sub.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/tanh_test.go b/ops/opset13/tanh_test.go index 1b05d14..86a7110 100644 --- a/ops/opset13/tanh_test.go +++ b/ops/opset13/tanh_test.go @@ -84,6 +84,7 @@ func TestInputValidationTanh(t *testing.T) { validated, err := tanh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go index 6acb9f7..4f7f4f2 100644 --- a/ops/opset13/transpose.go +++ b/ops/opset13/transpose.go @@ -8,6 +8,8 @@ import ( "gorgonia.org/tensor" ) +const TransposeInputs = 1 + // Transpose represents the ONNX transpose operator. 
type Transpose struct { perm []int @@ -34,6 +36,7 @@ func (t *Transpose) Init(attributes []*onnx.AttributeProto) error { for _, val := range attrPerm { t.perm = append(t.perm, int(val)) } + return nil } @@ -54,12 +57,12 @@ func (t *Transpose) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. func (t *Transpose) GetMinInputs() int { - return 1 + return TransposeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (t *Transpose) GetMaxInputs() int { - return 1 + return TransposeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/transpose_test.go b/ops/opset13/transpose_test.go index f80b77f..37d00fe 100644 --- a/ops/opset13/transpose_test.go +++ b/ops/opset13/transpose_test.go @@ -98,6 +98,7 @@ func TestInputValidationTranspose(t *testing.T) { validated, err := transpose.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/unsqueeze_test.go b/ops/opset13/unsqueeze_test.go index ea1b151..c4140c0 100644 --- a/ops/opset13/unsqueeze_test.go +++ b/ops/opset13/unsqueeze_test.go @@ -20,21 +20,23 @@ func TestUnsqueezeInit(t *testing.T) { func TestAxesOutRangeError(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := []int64{4} data := ops.Arange(9, 1) // 3 x 3 tensor dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) expected := fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, 3, 3) assert.Equal(t, err, expected) } func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) // -1 
will be offset to 3 (since outputrank = 4) axes := []int64{3, -1} @@ -42,20 +44,21 @@ func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [3 3]") } func TestDuplicateEntriesNotAllowed(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := []int64{0, 0} data := ops.Arange(9, 1) // 3 x 3 tensor dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [0 0]") } @@ -108,7 +111,8 @@ func TestUnsqueeze(t *testing.T) { } for _, test := range tests { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := test.axes data := test.data @@ -168,6 +172,7 @@ func TestInputValidationUnsqueeze(t *testing.T) { validated, err := unsqueeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } From 5dad4475d213436843d10f6445f6aa46e495c3e9 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 6 Nov 2023 14:28:54 +0100 Subject: [PATCH 10/15] Fix all lints --- ops/errors.go | 156 +++++++++++++++++--------- ops/multidir_broadcast.go | 2 +- ops/multidir_broadcast_test.go | 15 +-- ops/opset13/add_test.go | 2 +- ops/opset13/constant_of_shape_test.go | 1 + ops/opset13/gather.go | 22 ++-- ops/opset13/gather_test.go | 5 +- ops/opset13/gemm_test.go | 3 +- ops/opset13/gru.go | 8 +- ops/opset13/matmul.go | 8 +- ops/opset13/mul_test.go | 8 +- ops/opset13/opset13.go | 2 +- ops/opset13/prelu.go | 51 ++++++--- 
ops/opset13/prelu_test.go | 5 +- ops/opset13/relu_test.go | 5 +- ops/opset13/reshape.go | 8 +- ops/opset13/reshape_test.go | 5 +- ops/opset13/scaler.go | 6 +- ops/opset13/scaler_test.go | 9 +- ops/opset13/shape_test.go | 5 +- ops/opset13/sigmoid_test.go | 5 +- ops/opset13/slice_test.go | 5 +- ops/opset13/squeeze.go | 2 +- ops/opset13/squeeze_test.go | 7 +- ops/opset13/sub_test.go | 5 +- ops/opset13/tanh_test.go | 5 +- ops/opset13/transpose.go | 6 +- ops/opset13/transpose_test.go | 9 +- ops/opset13/unsqueeze.go | 7 +- ops/opset13/unsqueeze_test.go | 9 +- opset.go | 6 +- opset_test.go | 4 +- 32 files changed, 210 insertions(+), 186 deletions(-) diff --git a/ops/errors.go b/ops/errors.go index a59b7c5..d2d069c 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -62,18 +62,6 @@ func ErrTypeAssert(expected string, actual any) error { return &TypeAssertError{expectedType: expected, actualType: actual} } -// UnknownAttributeErrTemplate is used to format an error -// when an operator finds an unknown attribute during its initialization. -const UnknownAttributeErrTemplate = "%v: unknown attribute: %v" - -// UnsupportedAttrErrTemplate is used to format an error when an operator receives -// an attribute that is not supported yet. -const UnsupportedAttrErrTemplate = "%v: %v attribute not supported yet" - -// InvalidAttrCountErrTemplate is used to format an error when an operator -// got the wrong amount of attributes. -const InvalidAttrCountErrTemplate = "%v: expected %v attributes, got %d" - // InvalidInputCountErrTemplate is used to format an error when an operator got // the wrong amount of input tensors. const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d" @@ -82,53 +70,101 @@ const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d" // the wrong amount of input tensors when optional inputs are present. 
const InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d" -type InvalidInputError struct { +// UnsupportedInputErrTemplate is used to format an error when an operator got +// the wrong amount of input tensors when optional inputs are present. +const UnsupportedInputErrTemplate = "unsupported input for %v: %v" + +// InvalidInputErrTemplate is used to format an error when an operator got +// an invalid input tensor as input. +const InvalidInputErrTemplate = "invalid input tensor for %v: %v" + +type InputErrorKind string + +const ( + InputErrorType InputErrorKind = "type" + InputErrorCount InputErrorKind = "count" + InputErrorUnsupported InputErrorKind = "unsupported" + InputErrorInvalid InputErrorKind = "invalid" +) + +type InputError struct { + kind InputErrorKind + operator Operator + reason string + + // Attributes for input type error. inputNumber int actualType string - operator Operator + + // Attributes for input count error. + hasOptionalInputs bool + actualCount int + + // Attributes for unsupported input error. 
+ inputName string } -func (i *InvalidInputError) Error() string { - return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) +func (i *InputError) Error() string { + switch i.kind { + case InputErrorType: + return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) + case InputErrorCount: + if i.hasOptionalInputs { + return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount) + } + + return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) + case InputErrorUnsupported: + return fmt.Sprintf(UnsupportedInputErrTemplate, i.operator, i.inputName) + case InputErrorInvalid: + return fmt.Sprintf(InvalidInputErrTemplate, i.operator, i.reason) + default: + return fmt.Sprintf("operator %s unknown error input error kind %s", i.operator.String(), i.kind) + } } func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error { - return &InvalidInputError{ + return &InputError{ + kind: InputErrorType, operator: operator, inputNumber: inputNumber, actualType: dType, } } -type InvalidInputCountError struct { - hasOptionalInputs bool - actualCount int - operator Operator -} - -func (i *InvalidInputCountError) Error() string { - if i.hasOptionalInputs { - return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount) - } - - return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) -} - func ErrInvalidInputCount(actual int, operator Operator) error { - return &InvalidInputCountError{ + return &InputError{ + kind: InputErrorCount, actualCount: actual, operator: operator, } } func ErrInvalidOptionalInputCount(actual int, operator Operator) error { - return &InvalidInputCountError{ + return &InputError{ + kind: InputErrorCount, 
hasOptionalInputs: true, actualCount: actual, operator: operator, } } +func ErrUnsupportedInput(inputName string, operator Operator) error { + return &InputError{ + kind: InputErrorUnsupported, + inputName: inputName, + operator: operator, + } +} + +func ErrInvalidInput(reason string, operator Operator) error { + return &InputError{ + kind: InputErrorInvalid, + reason: reason, + operator: operator, + } +} + type BroadcastError struct { broadcastType string shapeA tensor.Shape @@ -170,36 +206,48 @@ func ErrInvalidTensor(reason string, operator Operator) error { return &InvalidTensorError{reason: reason, operator: operator} } -var ErrIncompatibleDimension = errors.New("incompatible dimensions") - -var UnknownOperatorTypeError = errors.New("unknown operator type") +var ErrUnsupportedOperator = errors.New("unsupported operator") func ErrUnknownOperatorType(operatorType string) error { - return fmt.Errorf("%w: %s", UnknownOperatorTypeError, operatorType) + return fmt.Errorf("%w: %s", ErrUnsupportedOperator, operatorType) } -// MultidirBroadcastErrTemplate is used to format an error when two inputs cannot be -// broadcasted together with Multidirectional broadcasting. -const MultidirBroadcastErrTemplate = "could not multidir broadcast inputs with shape %d and %d: %v" - -// UnidirBroadcastErrTemplate is used to format an error when two inputs cannot be -// broadcasted together with Unidirectional broadcasting. -const UnidirBroadcastErrTemplate = "could not unidir broadcast inputs with shape %d and %d" - -// AxisOutOfRangeErrTemplate is used to format an error when an given axis is out of range -// given a certain rank. -const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x < %d, was %d" - -// AxesNotAllInRangeErrTemplate is used to format an error when not all indices -// are within a given range. 
-const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d" - var ErrAxisNotInRange = errors.New("axis out of range") -func ErrNotAllAxisInRange(min, max int) error { +func ErrNotAllAxesInRange(min, max int) error { return fmt.Errorf("%w: all indices entries must be in the range -%d <= x < %d", ErrAxisNotInRange, min, max) } func ErrAxisOutOfRange(min, max, actual int) error { return fmt.Errorf("%w: axis argument must be in the range -%d <= x < %d, was %d", ErrAxisNotInRange, min, max, actual) } + +var ErrUnsupportedOpsetVersion = errors.New("unsupported opset version") + +type DimensionErrorKind string + +const ( + DimensionErrorIncompatible DimensionErrorKind = "incompatible" +) + +type DimensionError struct { + kind DimensionErrorKind + reason string +} + +func (d *DimensionError) Error() string { + switch d.kind { + case DimensionErrorIncompatible: + return fmt.Sprintf("dimensions error: incompatible dimensions") + default: + return fmt.Sprintf("dimension error: %s", d.reason) + } +} + +func ErrIncompatibleDimensions() error { + return &DimensionError{kind: DimensionErrorIncompatible, reason: ""} +} + +func ErrDimension(reason string) error { + return &DimensionError{reason: reason} +} diff --git a/ops/multidir_broadcast.go b/ops/multidir_broadcast.go index a55a361..3b016dd 100644 --- a/ops/multidir_broadcast.go +++ b/ops/multidir_broadcast.go @@ -83,7 +83,7 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens } default: - return nil, nil, ErrIncompatibleDimension + return nil, nil, ErrIncompatibleDimensions() } } } diff --git a/ops/multidir_broadcast_test.go b/ops/multidir_broadcast_test.go index 3433a37..02e5cac 100644 --- a/ops/multidir_broadcast_test.go +++ b/ops/multidir_broadcast_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -47,22 +46,12 @@ func TestMultidirectionalBroadcast(t *testing.T) { { [][]int{{1, 4, 5}, {2, 1, 1, 3}}, nil, - 
fmt.Errorf( - MultidirBroadcastErrTemplate, - []int{1, 4, 5}, - []int{2, 1, 1, 3}, - "incompatible dimensions", - ), + ErrMultidirBroadcast([]int{1, 4, 5}, []int{2, 1, 1, 3}, ErrIncompatibleDimensions()), }, { [][]int{{5}, {2, 3, 4}}, nil, - fmt.Errorf( - MultidirBroadcastErrTemplate, - []int{5}, - []int{2, 3, 4}, - "incompatible dimensions", - ), + ErrMultidirBroadcast([]int{5}, []int{2, 3, 4}, ErrIncompatibleDimensions()), }, } diff --git a/ops/opset13/add_test.go b/ops/opset13/add_test.go index 2b48621..4944bf2 100644 --- a/ops/opset13/add_test.go +++ b/ops/opset13/add_test.go @@ -66,7 +66,7 @@ func TestAddFail(t *testing.T) { add := &Add{} _, err := add.Apply(inputs) - assert.Equal(t, err, ops.ErrMultidirBroadcast(inputs[0].Shape(), inputs[1].Shape(), ops.ErrIncompatibleDimension)) + assert.Equal(t, err, ops.ErrMultidirBroadcast(inputs[0].Shape(), inputs[1].Shape(), ops.ErrIncompatibleDimensions())) } func TestInputValidationAdd(t *testing.T) { diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index d1cc3e9..17e6053 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -86,6 +86,7 @@ func TestConstantOfShape(t *testing.T) { // Make the input tensor tp := TensorProtoFromNumber(test.input) assert.NotNil(t, tp) + attr := []*onnx.AttributeProto{{Name: "value", T: tp}} // Create operator diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index c0406e2..07d9c45 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -23,16 +21,14 @@ type Gather struct { // newGather creates a new gather operator. func newGather() ops.Operator { - return &Gather{} + return &Gather{ + axis: 0, + } } // Init initializes the gather operator. 
func (g *Gather) Init(attributes []*onnx.AttributeProto) error { - switch length := len(attributes); { - case length > 1: - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, g, "0 or 1", len(attributes)) - - case length == 1: + if len(attributes) == 1 { attr := attributes[0] if attr.GetName() == "axis" { @@ -40,8 +36,8 @@ func (g *Gather) Init(attributes []*onnx.AttributeProto) error { } else { return ops.ErrInvalidAttribute(attr.GetName(), g) } - default: - g.axis = 0 + } else if len(attributes) > 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), g) } return nil @@ -64,7 +60,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { dataAxis := g.axis if dataAxis < -rank || dataAxis > rank-1 { - return nil, fmt.Errorf(ops.AxisOutOfRangeErrTemplate, rank, rank, dataAxis) + return nil, ops.ErrAxisOutOfRange(rank, rank, dataAxis) } // Offset axis if a negative index is given. if dataAxis < 0 { @@ -75,7 +71,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // dimension which is selected by `axis`) axisDimSize := data.Shape()[dataAxis] if !ops.AllInRange(indicesData, -axisDimSize, axisDimSize-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, axisDimSize, axisDimSize) + return nil, ops.ErrNotAllAxesInRange(axisDimSize, axisDimSize) } err = ops.OffsetTensorIfNegative(indices, axisDimSize) @@ -182,7 +178,7 @@ func gather(out, data, indices tensor.Tensor, axis int) error { k, ok := at.(int) if !ok { - return fmt.Errorf("could not cast to int") + return ops.ErrTypeAssert("int", at) } // Slice that selects `k` on the given axis. 
diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index e6c4da1..a9cd9ae 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -277,14 +276,14 @@ func TestInputValidationGather(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("gather operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Gather{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), }, - fmt.Errorf("gather operator: input 1 does not allow type float32"), + ops.ErrInvalidInputType(1, "float32", &Gather{}), }, } diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index 3d66ae6..d882b16 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -25,7 +24,7 @@ func TestGemmInitFail(t *testing.T) { gemm := &Gemm{} err := gemm.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, gemm, "unknownAttribute") + expected := ops.ErrInvalidAttribute("unknownAttribute", gemm) assert.Equal(t, expected, err) } diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 84c95ec..1cb1258 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -58,7 +56,7 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { B := inputs[3] if inputs[4] != nil { - return nil, fmt.Errorf("%v: sequence lens not yet supported as input", g) + return nil, ops.ErrUnsupportedInput("sequence lens", g) } initialH := inputs[5] @@ -91,7 +89,7 @@ func (g *GRU) Apply(inputs []tensor.Tensor) 
([]tensor.Tensor, error) { var ok bool prevH, ok = initialH.Clone().(tensor.Tensor) if !ok { - return nil, fmt.Errorf("could not clone the initial hidden state tensor") + return nil, ops.ErrTypeAssert("tensor.Tensor", initialH.Clone()) } } @@ -153,7 +151,7 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { Yh, ok := prevH.Clone().(tensor.Tensor) if !ok { - return nil, fmt.Errorf("could not clone the hidden tensor") + return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone()) } // Reshape the output so it adds the num_directions as specified by onnx. diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index 81da79a..fc6365e 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -51,7 +49,7 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { A, ok := A.Clone().(tensor.Tensor) if !ok { - return nil, ops.ErrTypeAssert("tensor.Tensor", A) + return nil, ops.ErrTypeAssert("tensor.Tensor", A.Clone()) } if err := A.Reshape(1, A.Shape()[0]); err != nil { @@ -66,7 +64,7 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { B, ok := B.Clone().(tensor.Tensor) if !ok { - return nil, ops.ErrTypeAssert("tensor.Tensor", A) + return nil, ops.ErrTypeAssert("tensor.Tensor", B.Clone()) } if err := B.Reshape(B.Shape()[0], 1); err != nil { @@ -172,7 +170,7 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten return nil, nil, err } default: - return nil, nil, fmt.Errorf("incompatible dimensions") + return nil, nil, ops.ErrIncompatibleDimensions() } } } diff --git a/ops/opset13/mul_test.go b/ops/opset13/mul_test.go index b4be910..e6d00e4 100644 --- a/ops/opset13/mul_test.go +++ b/ops/opset13/mul_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" 
@@ -70,12 +69,7 @@ func TestMulFail(t *testing.T) { assert.Equal( t, err, - fmt.Errorf( - ops.MultidirBroadcastErrTemplate, - []int{2, 2}, - []int{3}, - "incompatible dimensions", - ), + ops.ErrMultidirBroadcast([]int{2, 2}, []int{3}, ops.ErrIncompatibleDimensions()), ) } diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index 4e8bd2c..122e8ec 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -42,7 +42,7 @@ func GetOperator(operatorType string) (ops.Operator, error) { // GetOpNames returns a list with operator names for opset 13. func GetOpNames() []string { - var opList []string + opList := make([]string, 0, len(operators13)) for opName := range operators13 { opList = append(opList, opName) diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go index 754cb18..5131dc8 100644 --- a/ops/opset13/prelu.go +++ b/ops/opset13/prelu.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -41,21 +39,25 @@ func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { switch x.Dtype() { case tensor.Float32: - calcPRelu(y.Data().([]float32), x.Data().([]float32), slope.Data().([]float32)) + err = calcPRelu[float32](y.Data(), x.Data(), slope.Data()) case tensor.Float64: - calcPRelu(y.Data().([]float64), x.Data().([]float64), slope.Data().([]float64)) + err = calcPRelu[float64](y.Data(), x.Data(), slope.Data()) case tensor.Uint32: - calcPRelu(y.Data().([]uint32), x.Data().([]uint32), slope.Data().([]uint32)) + err = calcPRelu[uint32](y.Data(), x.Data(), slope.Data()) case tensor.Uint64: - calcPRelu(y.Data().([]uint64), x.Data().([]uint64), slope.Data().([]uint64)) + err = calcPRelu[uint64](y.Data(), x.Data(), slope.Data()) case tensor.Int32: - calcPRelu(y.Data().([]int32), x.Data().([]int32), slope.Data().([]int32)) + err = calcPRelu[int32](y.Data(), x.Data(), slope.Data()) case tensor.Int64: - 
calcPRelu(y.Data().([]int64), x.Data().([]int64), slope.Data().([]int64)) + err = calcPRelu[int64](y.Data(), x.Data(), slope.Data()) default: return nil, ops.ErrInvalidInputType(0, x.Dtype().String(), op) } + if err != nil { + return nil, err + } + return []tensor.Tensor{y}, nil } @@ -68,7 +70,7 @@ func (op *PRelu) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) x, slope := inputs[0], inputs[1] if x.Dtype() != slope.Dtype() { - return nil, fmt.Errorf("%v: type of slope (%s) does not match type of X (%s)", op, slope.Dtype(), x.Dtype()) + return nil, ops.ErrInvalidTensor("DType of 'slope' does not match DType of 'x'", op) } return inputs, nil @@ -98,12 +100,35 @@ func (op *PRelu) String() string { return "prelu operator" } -func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result []T, input []T, slope []T) { - for i, v := range input { +func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result any, input any, slope any) error { + var convertedResult []T + + var convertedInput []T + + var convertedSlope []T + + convertedResult, ok := result.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", result) + } + + convertedInput, ok = input.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", result) + } + + convertedSlope, ok = slope.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", result) + } + + for i, v := range convertedInput { if v < 0 { - v = slope[i] * v + v = convertedSlope[i] * v } - result[i] = v + convertedResult[i] = v } + + return nil } diff --git a/ops/opset13/prelu_test.go b/ops/opset13/prelu_test.go index 38d2de9..763cbfb 100644 --- a/ops/opset13/prelu_test.go +++ b/ops/opset13/prelu_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -67,14 +66,14 @@ func TestInputValidationPRelu(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("prelu operator: expected 2 input tensors, got 0"), + ops.ErrInvalidInputCount(0, 
&PRelu{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - fmt.Errorf("prelu operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &PRelu{}), }, } diff --git a/ops/opset13/relu_test.go b/ops/opset13/relu_test.go index ba417a6..b2d5fa0 100644 --- a/ops/opset13/relu_test.go +++ b/ops/opset13/relu_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -68,11 +67,11 @@ func TestInputValidationRelu(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("relu operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Relu{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("relu operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Relu{}), }, } diff --git a/ops/opset13/reshape.go b/ops/opset13/reshape.go index a255cc5..140d24d 100644 --- a/ops/opset13/reshape.go +++ b/ops/opset13/reshape.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -42,7 +40,7 @@ func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, ok := t.Clone().(tensor.Tensor) if !ok { - return nil, fmt.Errorf("could not cast to tensor") + return nil, ops.ErrTypeAssert("tensor.Tensor", t.Clone()) } err = out.Reshape(newShape...) 
@@ -80,7 +78,7 @@ func processShape(newShape, currentShape []int) error { for i := 0; i < len(newShape); i++ { if newShape[i] == 0 { if i >= len(currentShape) { - return fmt.Errorf("could not infer dim size") + return ops.ErrDimension("could not infer dim size") } newShape[i] = currentShape[i] @@ -101,7 +99,7 @@ func processShape(newShape, currentShape []int) error { } if newShape[j] == -1 { - return fmt.Errorf("at most one -1 dim size is allowed") + return ops.ErrDimension("at most one -1 dim size is allowed") } remainingSize /= newShape[j] diff --git a/ops/opset13/reshape_test.go b/ops/opset13/reshape_test.go index de642b7..8651d32 100644 --- a/ops/opset13/reshape_test.go +++ b/ops/opset13/reshape_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -84,14 +83,14 @@ func TestInputValidationReshape(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("reshape operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Reshape{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("reshape operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &Reshape{}), }, } diff --git a/ops/opset13/scaler.go b/ops/opset13/scaler.go index f642979..811ce80 100644 --- a/ops/opset13/scaler.go +++ b/ops/opset13/scaler.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -27,7 +25,7 @@ func newScaler() ops.Operator { // Init initializes the scaler operator. 
func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { if len(attributes) != ScalerExpectedAttributes { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, s, 2, len(attributes)) + return ops.ErrInvalidAttributeCount(ScalerExpectedAttributes, len(attributes), s) } for _, attr := range attributes { @@ -39,7 +37,7 @@ func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { floats := attr.GetFloats() s.scale = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) default: - return fmt.Errorf(ops.UnknownAttributeErrTemplate, s, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), s) } } diff --git a/ops/opset13/scaler_test.go b/ops/opset13/scaler_test.go index 1e86462..6dec7a7 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/opset13/scaler_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -23,7 +22,7 @@ func TestScalerInitFailWrongAttribute(t *testing.T) { scaler := &Scaler{} err := scaler.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, scaler, "unknownAttribute") + expected := ops.ErrInvalidAttribute("Another", scaler) assert.Equal(t, expected, err) } @@ -31,7 +30,7 @@ func TestScalerInitFailAttrCount(t *testing.T) { scaler := &Scaler{} err := scaler.Init([]*onnx.AttributeProto{}) - expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, scaler, 2, 0) + expected := ops.ErrInvalidAttributeCount(0, 2, scaler) assert.Equal(t, expected, err) } @@ -109,11 +108,11 @@ func TestInputValidationScaler(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("scaler operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Scaler{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("scaler operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Scaler{}), }, } diff --git a/ops/opset13/shape_test.go 
b/ops/opset13/shape_test.go index 7cc964b..1ab9382 100644 --- a/ops/opset13/shape_test.go +++ b/ops/opset13/shape_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -60,11 +59,11 @@ func TestInputValidationShape(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("shape operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Shape{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("shape operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Shape{}), }, } diff --git a/ops/opset13/sigmoid_test.go b/ops/opset13/sigmoid_test.go index 3e64bd3..3277a6f 100644 --- a/ops/opset13/sigmoid_test.go +++ b/ops/opset13/sigmoid_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -79,11 +78,11 @@ func TestInputValidationSigmoid(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("sigmoid operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Sigmoid{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("sigmoid operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Sigmoid{}), }, } diff --git a/ops/opset13/slice_test.go b/ops/opset13/slice_test.go index 7165c96..652608f 100644 --- a/ops/opset13/slice_test.go +++ b/ops/opset13/slice_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -200,7 +199,7 @@ func TestInputValidationSlice(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - fmt.Errorf("slice operator: expected 3-5 input tensors, got 1"), + ops.ErrInvalidOptionalInputCount(1, &Slice{}), }, { []tensor.Tensor{ @@ -209,7 +208,7 @@ func TestInputValidationSlice(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("slice operator: input 1 does not allow type int"), + 
ops.ErrInvalidInputType(1, "int", &Slice{}), }, } diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index 9ee81cc..b5cbb73 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -33,7 +33,7 @@ func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) if !ops.AllInRange(dimsToSqueeze, -nDims, nDims-1) { - return nil, ops.ErrNotAllAxisInRange(nDims, nDims) + return nil, ops.ErrNotAllAxesInRange(nDims, nDims) } // negative entries should be offset by the rank of the output tensor diff --git a/ops/opset13/squeeze_test.go b/ops/opset13/squeeze_test.go index 38279af..bb160b5 100644 --- a/ops/opset13/squeeze_test.go +++ b/ops/opset13/squeeze_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -149,7 +148,7 @@ func TestInputValidationSqueeze(t *testing.T) { { []tensor.Tensor{}, nil, - fmt.Errorf("squeeze operator: expected 1-2 input tensors, got 0"), + ops.ErrInvalidOptionalInputCount(0, &Squeeze{}), }, { []tensor.Tensor{ @@ -158,7 +157,7 @@ func TestInputValidationSqueeze(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("squeeze operator: expected 1-2 input tensors, got 3"), + ops.ErrInvalidOptionalInputCount(3, &Squeeze{}), }, { []tensor.Tensor{ @@ -166,7 +165,7 @@ func TestInputValidationSqueeze(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("squeeze operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &Squeeze{}), }, } diff --git a/ops/opset13/sub_test.go b/ops/opset13/sub_test.go index d9830af..6812be0 100644 --- a/ops/opset13/sub_test.go +++ b/ops/opset13/sub_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -104,14 +103,14 @@ func TestInputValidationSub(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - 
fmt.Errorf("sub operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Sub{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("sub operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Sub{}), }, } diff --git a/ops/opset13/tanh_test.go b/ops/opset13/tanh_test.go index 86a7110..44b5409 100644 --- a/ops/opset13/tanh_test.go +++ b/ops/opset13/tanh_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -71,11 +70,11 @@ func TestInputValidationTanh(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("tanh operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Tanh{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("tanh operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Tanh{}), }, } diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go index 4f7f4f2..d1e1780 100644 --- a/ops/opset13/transpose.go +++ b/ops/opset13/transpose.go @@ -1,8 +1,6 @@ package opset13 import ( - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" @@ -23,13 +21,13 @@ func newTranspose() ops.Operator { // Init initializes the transpose operator. 
func (t *Transpose) Init(attributes []*onnx.AttributeProto) error { if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, t, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), t) } attr := attributes[0] if attr.GetName() != "perm" { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, t, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), t) } attrPerm := attr.GetInts() diff --git a/ops/opset13/transpose_test.go b/ops/opset13/transpose_test.go index 37d00fe..005b2b6 100644 --- a/ops/opset13/transpose_test.go +++ b/ops/opset13/transpose_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -22,7 +21,7 @@ func TestTransposeInitFailWrongAttribute(t *testing.T) { trans := &Transpose{} err := trans.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, trans, "unknownAttribute") + expected := ops.ErrInvalidAttribute("unknownAttribute", trans) assert.Equal(t, expected, err) } @@ -30,7 +29,7 @@ func TestTransposeInitFailAttrCount(t *testing.T) { trans := &Transpose{} err := trans.Init([]*onnx.AttributeProto{}) - expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, trans, 1, 0) + expected := ops.ErrInvalidAttributeCount(1, 0, trans) assert.Equal(t, expected, err) } @@ -85,11 +84,11 @@ func TestInputValidationTranspose(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("transpose operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Transpose{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("transpose operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Transpose{}), }, } diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index cd7d0ec..8ba0f4e 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "sort" 
"github.com/advancedclimatesystems/gonnx/onnx" @@ -37,7 +36,7 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { outputRank := len(dataShape) + len(axes) if !ops.AllInRange(axes, -outputRank, outputRank-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, outputRank, outputRank) + return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) } // negative entries should be offset by the rank of the output tensor @@ -47,14 +46,14 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { sort.Ints(axes) if ops.HasDuplicates(axes) { - return nil, fmt.Errorf("Axes cannot have duplicate entries after offset, axes: %v", axes) + return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u) } newShape := insertOnes(dataShape, axes) out, ok := inputs[0].Clone().(tensor.Tensor) if !ok { - return nil, fmt.Errorf("could not copy the input tensor") + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) } err = out.Reshape(newShape...) 
diff --git a/ops/opset13/unsqueeze_test.go b/ops/opset13/unsqueeze_test.go index c4140c0..215007f 100644 --- a/ops/opset13/unsqueeze_test.go +++ b/ops/opset13/unsqueeze_test.go @@ -1,7 +1,6 @@ package opset13 import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/ops" @@ -29,7 +28,7 @@ func TestAxesOutRangeError(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) - expected := fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, 3, 3) + expected := ops.ErrNotAllAxesInRange(3, 3) assert.Equal(t, err, expected) } @@ -149,21 +148,21 @@ func TestInputValidationUnsqueeze(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("unsqueeze operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(2, &Unsqueeze{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), }, - fmt.Errorf("unsqueeze operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Unsqueeze{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), }, - fmt.Errorf("unsqueeze operator: input 1 does not allow type int32"), + ops.ErrInvalidInputType(1, "int32", &Unsqueeze{}), }, } diff --git a/opset.go b/opset.go index 2ee297f..2b83900 100644 --- a/opset.go +++ b/opset.go @@ -1,14 +1,10 @@ package gonnx import ( - "errors" - "github.com/advancedclimatesystems/gonnx/ops" "github.com/advancedclimatesystems/gonnx/ops/opset13" ) -var ErrInvalidOperator = errors.New("invalid operator getter") - // OpGetter is a function that gets an operator based on a string. 
type OpGetter func(string) (ops.Operator, error) @@ -22,5 +18,5 @@ func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { return getOperator, nil } - return nil, ErrInvalidOperator + return nil, ops.ErrUnsupportedOpsetVersion } diff --git a/opset_test.go b/opset_test.go index e6bb90b..bd986c9 100644 --- a/opset_test.go +++ b/opset_test.go @@ -1,14 +1,14 @@ package gonnx import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" ) func TestResolveOperatorGetterFail(t *testing.T) { opGetter, err := ResolveOperatorGetter(12) assert.Nil(t, opGetter) - assert.Equal(t, fmt.Errorf("expected opset to be in [13], got operator set 12"), err) + assert.Equal(t, ops.ErrUnsupportedOpsetVersion, err) } From 012a1aea94212fff578bf5c9fa7760ba39703015 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Mon, 6 Nov 2023 21:36:26 +0100 Subject: [PATCH 11/15] Fixed part of tests --- ops/opset13/add.go | 4 ++-- ops/opset13/cast_test.go | 2 +- ops/opset13/constant.go | 4 ++-- ops/opset13/constant_of_shape.go | 2 +- ops/opset13/constant_of_shape_test.go | 8 ++++---- ops/opset13/gather_test.go | 8 ++++---- ops/opset13/gemm_test.go | 4 ++-- ops/opset13/gru_test.go | 4 ++-- ops/utils.go | 2 +- ops/validate_inputs_test.go | 2 +- 10 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ops/opset13/add.go b/ops/opset13/add.go index 32a5ded..aa67cf5 100644 --- a/ops/opset13/add.go +++ b/ops/opset13/add.go @@ -8,10 +8,10 @@ import ( const ( // MinAddInput is the minimimum amount of inputs the add operator expects. - MinAddInput = 1 + MinAddInput = 2 // MaxAddInput is the maximum amount of inputs the add operator accepts. - MaxAddInput = 1 + MaxAddInput = 2 ) // Add represents the ONNX add operator. 
diff --git a/ops/opset13/cast_test.go b/ops/opset13/cast_test.go index 5ab663b..8c32b14 100644 --- a/ops/opset13/cast_test.go +++ b/ops/opset13/cast_test.go @@ -98,7 +98,7 @@ func TestInputValidationCast(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{true, false}, 2), }, - ops.ErrInvalidInputType(1, "bool", &Cast{}), + ops.ErrInvalidInputType(0, "bool", &Cast{}), }, } diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go index 43f2988..51f8aad 100644 --- a/ops/opset13/constant.go +++ b/ops/opset13/constant.go @@ -35,7 +35,7 @@ func (c *Constant) Init(attributes []*onnx.AttributeProto) error { switch attr.GetName() { case "sparse_value", "value_string", "value_strings": - return ops.ErrInvalidAttribute(attr.GetName(), c) + return ops.ErrUnsupportedAttribute(attr.GetName(), c) case "value": t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { @@ -54,7 +54,7 @@ func (c *Constant) Init(attributes []*onnx.AttributeProto) error { ints := attr.GetInts() c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) default: - return ops.ErrInvalidAttribute(attr.GetName(), c) + return ops.ErrUnsupportedAttribute(attr.GetName(), c) } return nil diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go index e861cac..d7dbc23 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/opset13/constant_of_shape.go @@ -64,7 +64,7 @@ func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) // Empty dimensions in a tensor are not supported for i := range shape { if shape[i] <= 0 { - return nil, ops.ErrInvalidTensor("no empty dimensions are allowed", c) + return nil, ops.ErrInvalidTensor("empty dimensions are not allowed", c) } } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index 17e6053..f41030d 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -137,7 +137,7 @@ func TestIncorrectInput(t 
*testing.T) { assert.NotNil(t, err) assert.Equal( t, - "Value input tensor should be a single element tensor, but was [1 2 3]", + "constant of shape operator invalid tensor found, reason: expected tensor to have one element", err.Error(), ) } @@ -154,7 +154,7 @@ func TestNegativeShapeNotAllowed(t *testing.T) { assert.Equal( t, - "Non positive dimensions are not allowed (must be > 0). Given: [1 -1]", + "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } @@ -170,7 +170,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) { assert.Equal( t, - "Non positive dimensions are not allowed (must be > 0). Given: [0]", + "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } @@ -200,7 +200,7 @@ func TestInputValidationConstantOfShape(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(1, &ConstantOfShape{}), + ops.ErrInvalidInputCount(0, &ConstantOfShape{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index a9cd9ae..4ab8fca 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -31,13 +31,13 @@ func TestGatherInitDefault(t *testing.T) { func TestGatherInitTooManyAttrs(t *testing.T) { op := Gather{} err := op.Init([]*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}) - assert.EqualError(t, err, "gather operator: expected 0 or 1 attributes, got 2") + assert.EqualError(t, err, "operator gather operator attribute error: invalid count 2 expected 1") } func TestGatherInitInvalidAttrName(t *testing.T) { op := Gather{} err := op.Init([]*onnx.AttributeProto{{Name: "axes"}}) // should be axis - assert.EqualError(t, err, "gather operator: unknown attribute: axes") + assert.EqualError(t, err, "operator gather operator attribute error: invalid attribute axes") } func TestGather(t *testing.T) { @@ -241,7 +241,7 @@ func TestGatherAxesIndexOutOfRange(t 
*testing.T) { _, err = op.Apply([]tensor.Tensor{dataIn, indicesIn}) assert.Error(t, err) - assert.EqualError(t, err, "axis argument must be in the range -1 <= x < 1, was 1") + assert.EqualError(t, err, "axis out of range: axis argument must be in the range -1 <= x < 1, was 1") } func TestGatherIndexOutOfRange(t *testing.T) { @@ -252,7 +252,7 @@ func TestGatherIndexOutOfRange(t *testing.T) { _, err := op.Apply([]tensor.Tensor{dataIn, indicesIn}) assert.Error(t, err) - assert.EqualError(t, err, "all indices entries must be in the range -1 <= x < 1") + assert.EqualError(t, err, "axis out of range: all indices entries must be in the range -1 <= x < 1") } func TestInputValidationGather(t *testing.T) { diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index d882b16..e0067cd 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -134,7 +134,7 @@ func TestInputValidationGemm(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - ops.ErrInvalidInputCount(1, &Gemm{}), + ops.ErrInvalidOptionalInputCount(1, &Gemm{}), }, { []tensor.Tensor{ @@ -144,7 +144,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, - ops.ErrInvalidInputCount(4, &Gemm{}), + ops.ErrInvalidOptionalInputCount(4, &Gemm{}), }, { []tensor.Tensor{ diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index cc278d3..39d419f 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -137,7 +137,7 @@ func TestInputValidationGRU(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - ops.ErrInvalidInputCount(1, &GRU{}), + ops.ErrInvalidOptionalInputCount(1, &GRU{}), }, { []tensor.Tensor{ @@ -206,7 +206,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(4, "int", &GRU{}), + ops.ErrInvalidInputType(5, "int", &GRU{}), }, } diff --git a/ops/utils.go b/ops/utils.go index 
2a81881..63296c6 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -222,7 +222,7 @@ func NElements(shp ...int) int { // PairwiseAssign essentially does pairwise t1 = t2 in place!. func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { if !t1.Shape().Eq(t2.Shape()) { - return + return ErrInvalidShape } it := t1.Iterator() diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index deb1409..40bb8fa 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -75,7 +75,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(1, 0), 0, - ErrInvalidInputCount(2, &MockOp{}), + ErrInvalidInputCount(1, &MockOp{}), }, { &MockOp{ From 88e58276c601b6e19cc1c44914c317353afa0bff Mon Sep 17 00:00:00 2001 From: wisse Date: Tue, 7 Nov 2023 08:03:17 +0100 Subject: [PATCH 12/15] Fix validate input tests. --- Makefile | 2 +- ops/validate_inputs_test.go | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 213b144..1d05943 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ LDFLAGS=-ldflags "-s -w -X main.Version=${VERSION}" TEST=$(shell go list ./... 
| grep -v /onnx/) BUILD_PARAMS=CGO_ENABLED=0 -GO1.19=ASSUME_NO_MOVING_GC_UNSAFE_RISK_IT_WITH=go1.19 +GO1.19=ASSUME_NO_MOVING_GC_UNSAFE_RISK_IT_WITH=go1.21 define echotask diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index 40bb8fa..ea91fbc 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -75,7 +75,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(1, 0), 0, - ErrInvalidInputCount(1, &MockOp{}), + ErrInvalidInputCount(1, &MockOp{minInputs: 2, maxInputs: 2}), }, { &MockOp{ @@ -91,7 +91,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(7, 0), 0, - ErrInvalidOptionalInputCount(7, &MockOp{}), + ErrInvalidOptionalInputCount(7, &MockOp{minInputs: 3, maxInputs: 5}), }, { &MockOp{ @@ -107,9 +107,11 @@ func TestValidateInputs(t *testing.T) { for _, test := range tests { inputs, err := ValidateInputs(test.op, test.inputs) + if test.err != nil { + assert.EqualError(t, err, test.err.Error()) + } expectedLength := len(test.inputs) + test.expectedNil - assert.Equal(t, test.err, err) assert.Equal(t, expectedLength, len(inputs)) // Check if the added nodes are all nil. 
From ed49008d7d0c163aed0a51f816ee6240b4f3fe73 Mon Sep 17 00:00:00 2001 From: wisse Date: Tue, 7 Nov 2023 08:11:05 +0100 Subject: [PATCH 13/15] Worked on fixing tests --- ops/opset13/matmul_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ops/opset13/matmul_test.go b/ops/opset13/matmul_test.go index 71a7ac0..fa8dcc2 100644 --- a/ops/opset13/matmul_test.go +++ b/ops/opset13/matmul_test.go @@ -83,7 +83,7 @@ func TestMatMul(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { matmul := &MatMul{} inputs := []tensor.Tensor{ ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), @@ -92,7 +92,7 @@ func TestMatMul(t *testing.T) { res, err := matmul.Apply(inputs) assert.Nil(t, err) - assert.Equal(t, test.expected, res[0].Data()) + assert.Equal(t, test.expected, res[0].Data(), "test number %d", i) assert.Equal(t, test.expectedShape, res[0].Shape()) } } From f9a2ac4e2e06ffbc10a7b6ff1c1dbf6717ab59f0 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 7 Nov 2023 08:43:50 +0100 Subject: [PATCH 14/15] Fixed rest of tests --- model_test.go | 4 ++-- ops/opset13/cast.go | 4 ++-- ops/opset13/concat.go | 2 +- ops/opset13/constant.go | 4 ++-- ops/opset13/constant_of_shape.go | 4 ++-- ops/opset13/div.go | 4 ++-- ops/opset13/gather.go | 4 ++-- ops/opset13/gemm.go | 4 ++-- ops/opset13/gru.go | 4 ++-- ops/opset13/matmul.go | 10 ++++++---- ops/opset13/scaler_test.go | 4 ++-- ops/opset13/squeeze.go | 2 +- ops/opset13/unsqueeze_test.go | 6 +++--- 13 files changed, 29 insertions(+), 27 deletions(-) diff --git a/model_test.go b/model_test.go index 1f86c41..ed1c4b9 100644 --- a/model_test.go +++ b/model_test.go @@ -37,7 +37,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(16)}, ), nil, - ErrModel("input %v only has %d dimensions, but index %d was required", "data_input", 3, 0), + ErrInvalidShape([]onnx.Dim{{IsDynamic: true, Name: "batch_size", Size: 0}, {IsDynamic: false, Name: "", Size: 3}}, []int{2, 4, 2}), }, { 
"./sample_models/onnx_models/mlp.onnx", @@ -47,7 +47,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(6)}, ), nil, - ErrModel("input %v does not exist", "tensor_data"), + ErrModel("tensor: %v not found", "data_input"), }, { "./sample_models/onnx_models/gru.onnx", diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index d465fe7..e5f1e5d 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinCastInput is the minimimum amount of inputs the add operator expects. + // MinCastInput is the minimimum amount of inputs the cast operator expects. MinCastInput = 1 - // MaxCastInput is the maximum amount of inputs the add operator accepts. + // MaxCastInput is the maximum amount of inputs the cast operator accepts. MaxCastInput = 1 ) diff --git a/ops/opset13/concat.go b/ops/opset13/concat.go index 1327885..3d347eb 100644 --- a/ops/opset13/concat.go +++ b/ops/opset13/concat.go @@ -7,7 +7,7 @@ import ( ) const ( - // MinConcatInput is the minimimum amount of inputs the add operator expects. + // MinConcatInput is the minimimum amount of inputs the concat operator expects. MinConcatInput = 1 ) diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go index 51f8aad..b667447 100644 --- a/ops/opset13/constant.go +++ b/ops/opset13/constant.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinConstInput is the minimimum amount of inputs the add operator expects. + // MinConstInput is the minimimum amount of inputs the constant operator expects. MinConstInput = 1 - // MaxConstInput is the maximum amount of inputs the add operator accepts. + // MaxConstInput is the maximum amount of inputs the constant operator accepts. 
MaxConstInput = 1 ) diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go index d7dbc23..2e40b73 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/opset13/constant_of_shape.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinConstanShapeOfInput is the minimimum amount of inputs the add operator expects. + // MinConstanShapeOfInput is the minimimum amount of inputs the constant of shape operator expects. MinConstanShapeOfInput = 1 - // MaxConstanShapeOfInput is the maximum amount of inputs the add operator accepts. + // MaxConstanShapeOfInput is the maximum amount of inputs the constant of shape operator accepts. MaxConstanShapeOfInput = 1 ) diff --git a/ops/opset13/div.go b/ops/opset13/div.go index 530aa89..72fa9a4 100644 --- a/ops/opset13/div.go +++ b/ops/opset13/div.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinDivInput is the minimimum amount of inputs the add operator expects. + // MinDivInput is the minimimum amount of inputs the div operator expects. MinDivInput = 2 - // MaxDivInput is the maximum amount of inputs the add operator accepts. + // MaxDivInput is the maximum amount of inputs the div operator accepts. MaxDivInput = 2 ) diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index 07d9c45..59abbe7 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinGatherInput is the minimimum amount of inputs the add operator expects. + // MinGatherInput is the minimimum amount of inputs the gather operator expects. MinGatherInput = 2 - // MaxGatherInput is the maximum amount of inputs the add operator accepts. + // MaxGatherInput is the maximum amount of inputs the gather operator accepts. MaxGatherInput = 2 ) diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index c0cbec9..0fc05f9 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinGemmInput is the minimimum amount of inputs the add operator expects. 
+ // MinGemmInput is the minimimum amount of inputs the gemm operator expects. MinGemmInput = 2 - // MaxGemmInput is the maximum amount of inputs the add operator accepts. + // MaxGemmInput is the maximum amount of inputs the gemm operator accepts. MaxGemmInput = 3 ) diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 1cb1258..8f7777e 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinGRUInput is the minimimum amount of inputs the add operator expects. + // MinGRUInput is the minimimum amount of inputs the gru operator expects. MinGRUInput = 3 - // MaxGRUInput is the maximum amount of inputs the add operator accepts. + // MaxGRUInput is the maximum amount of inputs the gru operator accepts. MaxGRUInput = 6 ) diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index fc6365e..c49c505 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -7,10 +7,10 @@ import ( ) const ( - // MinMatMulInput is the minimimum amount of inputs the add operator expects. + // MinMatMulInput is the minimimum amount of inputs the matmul operator expects. MinMatMulInput = 2 - // MaxMatMulInput is the maximum amount of inputs the add operator accepts. + // MaxMatMulInput is the maximum amount of inputs the matmul operator accepts. MaxMatMulInput = 2 ) @@ -29,6 +29,8 @@ func (m *MatMul) Init(_ []*onnx.AttributeProto) error { // Apply applies the matmul operator. 
func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ok bool + A := inputs[0] B := inputs[1] @@ -47,7 +49,7 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if len(A.Shape()) == 1 { prependedDimension = true - A, ok := A.Clone().(tensor.Tensor) + A, ok = A.Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", A.Clone()) } @@ -62,7 +64,7 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if len(B.Shape()) == 1 { appendedDimension = true - B, ok := B.Clone().(tensor.Tensor) + B, ok = B.Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", B.Clone()) } diff --git a/ops/opset13/scaler_test.go b/ops/opset13/scaler_test.go index 6dec7a7..b49d5aa 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/opset13/scaler_test.go @@ -22,7 +22,7 @@ func TestScalerInitFailWrongAttribute(t *testing.T) { scaler := &Scaler{} err := scaler.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}) - expected := ops.ErrInvalidAttribute("Another", scaler) + expected := ops.ErrInvalidAttribute("unknownAttribute", scaler) assert.Equal(t, expected, err) } @@ -30,7 +30,7 @@ func TestScalerInitFailAttrCount(t *testing.T) { scaler := &Scaler{} err := scaler.Init([]*onnx.AttributeProto{}) - expected := ops.ErrInvalidAttributeCount(0, 2, scaler) + expected := ops.ErrInvalidAttributeCount(2, 0, scaler) assert.Equal(t, expected, err) } diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index b5cbb73..88ac56f 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -7,7 +7,7 @@ import ( ) const ( - SqueezeMinInput = 2 + SqueezeMinInput = 1 SqueezeMaxInput = 2 ) diff --git a/ops/opset13/unsqueeze_test.go b/ops/opset13/unsqueeze_test.go index 215007f..445d0c5 100644 --- a/ops/opset13/unsqueeze_test.go +++ b/ops/opset13/unsqueeze_test.go @@ -44,7 +44,7 @@ func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { dataIn := 
ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [3 3]") + assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") } func TestDuplicateEntriesNotAllowed(t *testing.T) { @@ -58,7 +58,7 @@ func TestDuplicateEntriesNotAllowed(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [0 0]") + assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") } func TestUnsqueeze(t *testing.T) { @@ -148,7 +148,7 @@ func TestInputValidationUnsqueeze(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputCount(2, &Unsqueeze{}), + ops.ErrInvalidInputCount(1, &Unsqueeze{}), }, { []tensor.Tensor{ From ce1ef8ab46d07183f98b26be4f5861d709a2222a Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Wed, 8 Nov 2023 14:40:10 +0100 Subject: [PATCH 15/15] Resolve all MR comments --- Makefile | 2 +- ops/convert.go | 13 ------------- ops/errors.go | 25 ++++++++++++++++++++----- ops/multidir_broadcast.go | 6 +++--- ops/opset13/abs.go | 11 ++++------- ops/opset13/add.go | 11 ++++------- ops/opset13/cast.go | 11 ++++------- ops/opset13/concat.go | 5 ++--- ops/opset13/constant.go | 8 -------- ops/opset13/constant_of_shape.go | 11 ++++------- ops/opset13/div.go | 11 ++++------- ops/opset13/div_test.go | 2 +- ops/opset13/gather.go | 11 ++++------- ops/opset13/gather_test.go | 4 ++-- ops/opset13/gemm.go | 11 ++++------- ops/opset13/gru.go | 11 ++++------- ops/opset13/matmul.go | 25 ++++++++++++------------- ops/opset13/mul.go | 11 ++++------- 
ops/opset13/prelu.go | 4 ++-- ops/opset13/scaler.go | 7 ++++--- ops/opset13/shape.go | 9 ++++++--- ops/opset13/slice.go | 11 ++++------- ops/opset13/squeeze.go | 10 +++++----- ops/opset13/sub.go | 11 ++++------- ops/opset13/transpose.go | 9 ++++++--- ops/opset13/unsqueeze.go | 10 ++++++---- ops/utils.go | 7 ------- 27 files changed, 114 insertions(+), 153 deletions(-) diff --git a/Makefile b/Makefile index 1d05943..213b144 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ LDFLAGS=-ldflags "-s -w -X main.Version=${VERSION}" TEST=$(shell go list ./... | grep -v /onnx/) BUILD_PARAMS=CGO_ENABLED=0 -GO1.19=ASSUME_NO_MOVING_GC_UNSAFE_RISK_IT_WITH=go1.21 +GO1.19=ASSUME_NO_MOVING_GC_UNSAFE_RISK_IT_WITH=go1.19 define echotask diff --git a/ops/convert.go b/ops/convert.go index 591d445..0637f49 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -1,23 +1,10 @@ package ops import ( - "errors" - "fmt" - "github.com/advancedclimatesystems/gonnx/onnx" "gorgonia.org/tensor" ) -var ErrConversion = errors.New("unable to convert") - -func ErrConversionInvalidType(dType tensor.Dtype, newType int32) error { - return fmt.Errorf("%w: type %v, to %v is invalid", ErrConversion, dType, newType) -} - -func ErrConversionNotSupported(dType int32) error { - return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) -} - // Number is a type which represents a number. 
type Number interface { float32 | float64 | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 diff --git a/ops/errors.go b/ops/errors.go index d2d069c..ec5a9f3 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -27,13 +27,13 @@ type AttributeError struct { func (t *AttributeError) Error() string { switch t.kind { case AttributeErrorCount: - return fmt.Sprintf("operator %s attribute error: invalid count %d expected %d", t.operator.String(), t.attributeCount, t.expectedCount) + return fmt.Sprintf("%s attribute error: invalid count %d expected %d", t.operator.String(), t.attributeCount, t.expectedCount) case AttributeErrorInvalid: - return fmt.Sprintf("operator %s attribute error: invalid attribute %s", t.operator.String(), t.attributeName) + return fmt.Sprintf("%s attribute error: invalid attribute %s", t.operator.String(), t.attributeName) case AttributeErrorUnsupported: - return fmt.Sprintf("operator %s attribute error: unsupported attribute %s", t.operator.String(), t.attributeName) + return fmt.Sprintf("%s attribute error: unsupported attribute %s", t.operator.String(), t.attributeName) default: - return fmt.Sprintf("operator %s unknown error attribute error kind %s", t.operator.String(), t.kind) + return fmt.Sprintf("%s unknown error attribute error kind %s", t.operator.String(), t.kind) } } @@ -119,7 +119,7 @@ func (i *InputError) Error() string { case InputErrorInvalid: return fmt.Sprintf(InvalidInputErrTemplate, i.operator, i.reason) default: - return fmt.Sprintf("operator %s unknown error input error kind %s", i.operator.String(), i.kind) + return fmt.Sprintf("%s unknown error input error kind %s", i.operator.String(), i.kind) } } @@ -251,3 +251,18 @@ func ErrIncompatibleDimensions() error { func ErrDimension(reason string) error { return &DimensionError{reason: reason} } + +var ( + ErrCast = errors.New("cast error") + ErrInvalidShape = errors.New("invalid shape error") +) + +var ErrConversion = errors.New("unable to convert") + +func 
ErrConversionInvalidType(dType tensor.Dtype, newType int32) error { + return fmt.Errorf("%w: type %v, to %v is invalid", ErrConversion, dType, newType) +} + +func ErrConversionNotSupported(dType int32) error { + return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) +} diff --git a/ops/multidir_broadcast.go b/ops/multidir_broadcast.go index 3b016dd..5565bd9 100644 --- a/ops/multidir_broadcast.go +++ b/ops/multidir_broadcast.go @@ -95,10 +95,10 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens // All extra dimensions are given size one (otherwise the tensor cannot be reshaped). // The given tensor is cloned such that the tensor is not modified in place. // Example: if we add 2 extra dimensions to shape (2, 3) we get shape (1, 1, 2, 3). -func addExtraDimsToTensor(t tensor.Tensor, nExtraDims int) (tensor.Tensor, error) { - t, ok := t.Clone().(tensor.Tensor) +func addExtraDimsToTensor(originalT tensor.Tensor, nExtraDims int) (tensor.Tensor, error) { + t, ok := originalT.Clone().(tensor.Tensor) if !ok { - return nil, ErrTypeAssert("tensor.Tensor", t.Clone()) + return nil, ErrTypeAssert("tensor.Tensor", originalT.Clone()) } newShape := []int{} diff --git a/ops/opset13/abs.go b/ops/opset13/abs.go index 15f8afc..6a0572b 100644 --- a/ops/opset13/abs.go +++ b/ops/opset13/abs.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinAbsInput is the minimimum amount of inputs the abs operator expects. - MinAbsInput = 1 - - // MaxAbsInput is the maximum amount of inputs the abs operator accepts. - MaxAbsInput = 1 + MinAbsInputs = 1 + MaxAbsInputs = 1 ) // Abs represents the ONNX abs operator. @@ -44,12 +41,12 @@ func (a *Abs) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (a *Abs) GetMinInputs() int { - return MinAbsInput + return MinAbsInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (a *Abs) GetMaxInputs() int { - return MaxAbsInput + return MaxAbsInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/add.go b/ops/opset13/add.go index aa67cf5..9163434 100644 --- a/ops/opset13/add.go +++ b/ops/opset13/add.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinAddInput is the minimimum amount of inputs the add operator expects. - MinAddInput = 2 - - // MaxAddInput is the maximum amount of inputs the add operator accepts. - MaxAddInput = 2 + MinAddInputs = 2 + MaxAddInputs = 2 ) // Add represents the ONNX add operator. @@ -49,12 +46,12 @@ func (a *Add) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (a *Add) GetMinInputs() int { - return MinAddInput + return MinAddInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (a *Add) GetMaxInputs() int { - return MaxAddInput + return MaxAddInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index e5f1e5d..ac79c6e 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinCastInput is the minimimum amount of inputs the cast operator expects. - MinCastInput = 1 - - // MaxCastInput is the maximum amount of inputs the cast operator accepts. - MaxCastInput = 1 + MinCastInputs = 1 + MaxCastInputs = 1 ) // Cast represents the ONNX cast operator. @@ -57,12 +54,12 @@ func (c *Cast) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *Cast) GetMinInputs() int { - return MinCastInput + return MinCastInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (c *Cast) GetMaxInputs() int { - return MaxCastInput + return MaxCastInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/concat.go b/ops/opset13/concat.go index 3d347eb..154839d 100644 --- a/ops/opset13/concat.go +++ b/ops/opset13/concat.go @@ -7,8 +7,7 @@ import ( ) const ( - // MinConcatInput is the minimimum amount of inputs the concat operator expects. - MinConcatInput = 1 + MinConcatInputs = 1 ) // Concat represents the ONNX concat operator. @@ -70,7 +69,7 @@ func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *Concat) GetMinInputs() int { - return MinConcatInput + return MinConcatInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go index b667447..758dd21 100644 --- a/ops/opset13/constant.go +++ b/ops/opset13/constant.go @@ -6,14 +6,6 @@ import ( "gorgonia.org/tensor" ) -const ( - // MinConstInput is the minimimum amount of inputs the constant operator expects. - MinConstInput = 1 - - // MaxConstInput is the maximum amount of inputs the constant operator accepts. - MaxConstInput = 1 -) - // Constant represents the ONNX constant operator. type Constant struct { value tensor.Tensor diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go index 2e40b73..a33d864 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/opset13/constant_of_shape.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinConstanShapeOfInput is the minimimum amount of inputs the constant of shape operator expects. - MinConstanShapeOfInput = 1 - - // MaxConstanShapeOfInput is the maximum amount of inputs the constant of shape operator accepts. 
- MaxConstanShapeOfInput = 1 + MinConstantOfShapeInputs = 1 + MaxConstantOfShapeInputs = 1 ) // ConstantOfShape represents the ONNX constant of shape operator. @@ -85,12 +82,12 @@ func (c *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tenso // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *ConstantOfShape) GetMinInputs() int { - return MinConstanShapeOfInput + return MinConstantOfShapeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (c *ConstantOfShape) GetMaxInputs() int { - return MaxConstanShapeOfInput + return MaxConstantOfShapeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/div.go b/ops/opset13/div.go index 72fa9a4..59bd531 100644 --- a/ops/opset13/div.go +++ b/ops/opset13/div.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinDivInput is the minimimum amount of inputs the div operator expects. - MinDivInput = 2 - - // MaxDivInput is the maximum amount of inputs the div operator accepts. - MaxDivInput = 2 + MinDivInputs = 2 + MaxDivInputs = 2 ) // Div represents the ONNX div operator. @@ -49,12 +46,12 @@ func (d *Div) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (d *Div) GetMinInputs() int { - return MinDivInput + return MinDivInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (d *Div) GetMaxInputs() int { - return MaxDivInput + return MaxDivInputs } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/div_test.go b/ops/opset13/div_test.go index 6af8f60..06a4f45 100644 --- a/ops/opset13/div_test.go +++ b/ops/opset13/div_test.go @@ -107,7 +107,7 @@ func TestInputValidationDiv(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1.0, &Div{}), + ops.ErrInvalidInputCount(1, &Div{}), }, { []tensor.Tensor{ diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index 59abbe7..5668779 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinGatherInput is the minimimum amount of inputs the gather operator expects. - MinGatherInput = 2 - - // MaxGatherInput is the maximum amount of inputs the gather operator accepts. - MaxGatherInput = 2 + MinGatherInputs = 2 + MaxGatherInputs = 2 ) // Gather represents the ONNX gather operator. @@ -99,12 +96,12 @@ func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *Gather) GetMinInputs() int { - return MinGatherInput + return MinGatherInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gather) GetMaxInputs() int { - return MaxGatherInput + return MaxGatherInputs } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index 4ab8fca..fb82347 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -31,13 +31,13 @@ func TestGatherInitDefault(t *testing.T) { func TestGatherInitTooManyAttrs(t *testing.T) { op := Gather{} err := op.Init([]*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}) - assert.EqualError(t, err, "operator gather operator attribute error: invalid count 2 expected 1") + assert.EqualError(t, err, "gather operator attribute error: invalid count 2 expected 1") } func TestGatherInitInvalidAttrName(t *testing.T) { op := Gather{} err := op.Init([]*onnx.AttributeProto{{Name: "axes"}}) // should be axis - assert.EqualError(t, err, "operator gather operator attribute error: invalid attribute axes") + assert.EqualError(t, err, "gather operator attribute error: invalid attribute axes") } func TestGather(t *testing.T) { diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index 0fc05f9..9268043 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinGemmInput is the minimimum amount of inputs the gemm operator expects. - MinGemmInput = 2 - - // MaxGemmInput is the maximum amount of inputs the gemm operator accepts. - MaxGemmInput = 3 + MinGemmInputs = 2 + MaxGemmInputs = 3 ) // Gemm represents the ONNX gemm operator. @@ -114,12 +111,12 @@ func (g *Gemm) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *Gemm) GetMinInputs() int { - return MinGemmInput + return MinGemmInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gemm) GetMaxInputs() int { - return MaxGemmInput + return MaxGemmInputs } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 8f7777e..80c8802 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinGRUInput is the minimimum amount of inputs the gru operator expects. - MinGRUInput = 3 - - // MaxGRUInput is the maximum amount of inputs the gru operator accepts. - MaxGRUInput = 6 + MinGRUInputs = 3 + MaxGRUInputs = 6 ) // GRU represents the ONNX gru operator. It only supports a simple forward gru @@ -170,12 +167,12 @@ func (g *GRU) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *GRU) GetMinInputs() int { - return MinGRUInput + return MinGRUInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *GRU) GetMaxInputs() int { - return MaxGRUInput + return MaxGRUInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index c49c505..2b5969a 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinMatMulInput is the minimimum amount of inputs the matmul operator expects. - MinMatMulInput = 2 - - // MaxMatMulInput is the maximum amount of inputs the matmul operator accepts. - MaxMatMulInput = 2 + MinMatMulInputs = 2 + MaxMatMulInputs = 2 ) // MatMul represents the ONNX matmul operator. @@ -29,8 +26,6 @@ func (m *MatMul) Init(_ []*onnx.AttributeProto) error { // Apply applies the matmul operator. 
func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ok bool - A := inputs[0] B := inputs[1] @@ -49,14 +44,16 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if len(A.Shape()) == 1 { prependedDimension = true - A, ok = A.Clone().(tensor.Tensor) + reshapedA, ok := A.Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", A.Clone()) } - if err := A.Reshape(1, A.Shape()[0]); err != nil { + if err := reshapedA.Reshape(1, reshapedA.Shape()[0]); err != nil { return nil, err } + + A = reshapedA } // If B is a vector, promote to a matrix for the calculation. @@ -64,14 +61,16 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if len(B.Shape()) == 1 { appendedDimension = true - B, ok = B.Clone().(tensor.Tensor) + reshapedB, ok := B.Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", B.Clone()) } - if err := B.Reshape(B.Shape()[0], 1); err != nil { + if err := reshapedB.Reshape(reshapedB.Shape()[0], 1); err != nil { return nil, err } + + B = reshapedB } // Now we have to perform batch matrix multiplication. First we need to broadcast @@ -116,12 +115,12 @@ func (m *MatMul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (m *MatMul) GetMinInputs() int { - return MinMatMulInput + return MinMatMulInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *MatMul) GetMaxInputs() int { - return MaxMatMulInput + return MaxMatMulInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go index 6395572..963872d 100644 --- a/ops/opset13/mul.go +++ b/ops/opset13/mul.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinMulInput is the minimimum amount of inputs the mul operator expects. 
- MinMulInput = 2 - - // MaxMulInput is the maximum amount of inputs the mul operator accepts. - MaxMulInput = 2 + MinMulInputs = 2 + MaxMulInputs = 2 ) // Mul represents the ONNX mul operator. @@ -49,12 +46,12 @@ func (m *Mul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (m *Mul) GetMinInputs() int { - return MinMulInput + return MinMulInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *Mul) GetMaxInputs() int { - return MaxMulInput + return MaxMulInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go index 5131dc8..95fa94b 100644 --- a/ops/opset13/prelu.go +++ b/ops/opset13/prelu.go @@ -114,12 +114,12 @@ func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result any convertedInput, ok = input.([]T) if !ok { - return ops.ErrTypeAssert("numeric list", result) + return ops.ErrTypeAssert("numeric list", input) } convertedSlope, ok = slope.([]T) if !ok { - return ops.ErrTypeAssert("numeric list", result) + return ops.ErrTypeAssert("numeric list", slope) } for i, v := range convertedInput { diff --git a/ops/opset13/scaler.go b/ops/opset13/scaler.go index 811ce80..b53ec35 100644 --- a/ops/opset13/scaler.go +++ b/ops/opset13/scaler.go @@ -8,7 +8,8 @@ import ( const ( ScalerExpectedAttributes = 2 - ScalerInputs = 1 + MinScalerInputs = 1 + MaxScalerInputs = 1 ) // Scaler represents the ONNX-ml scaler operator. @@ -76,12 +77,12 @@ func (s *Scaler) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Scaler) GetMinInputs() int { - return ScalerInputs + return MinScalerInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (s *Scaler) GetMaxInputs() int { - return ScalerInputs + return MaxScalerInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/shape.go b/ops/opset13/shape.go index 16b0978..82a434c 100644 --- a/ops/opset13/shape.go +++ b/ops/opset13/shape.go @@ -6,7 +6,10 @@ import ( "gorgonia.org/tensor" ) -const NShapeInputs = 1 +const ( + MinShapeInputs = 1 + MaxShapeInputs = 1 +) // Shape represents the ONNX shape operator. type Shape struct{} @@ -43,12 +46,12 @@ func (s *Shape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Shape) GetMinInputs() int { - return NShapeInputs + return MinShapeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Shape) GetMaxInputs() int { - return NShapeInputs + return MaxShapeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/slice.go b/ops/opset13/slice.go index c31715f..5aeeb97 100644 --- a/ops/opset13/slice.go +++ b/ops/opset13/slice.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinSliceInput is the minimimum amount of inputs the slice operator expects. - MinSliceInput = 3 - - // MaxSliceInput is the maximum amount of inputs the slice operator accepts. - MaxSliceInput = 5 + MinSliceInputs = 3 + MaxSliceInputs = 5 ) // Slice represents the ONNX slice operator. @@ -74,12 +71,12 @@ func (s *Slice) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Slice) GetMinInputs() int { - return MinSliceInput + return MinSliceInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
func (s *Slice) GetMaxInputs() int { - return MaxSliceInput + return MaxSliceInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index 88ac56f..47b5e1f 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -7,8 +7,8 @@ import ( ) const ( - SqueezeMinInput = 1 - SqueezeMaxInput = 2 + MinSqueezeInputs = 1 + MaxSqueezeInputs = 2 ) // Squeeze represents the ONNX squeeze operator. @@ -51,7 +51,7 @@ func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, ok := inputs[0].Clone().(tensor.Tensor) if !ok { - return nil, nil + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) } err = out.Reshape(newShape...) @@ -66,12 +66,12 @@ func (s *Squeeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Squeeze) GetMinInputs() int { - return SqueezeMinInput + return MinSqueezeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Squeeze) GetMaxInputs() int { - return SqueezeMaxInput + return MaxSqueezeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/sub.go b/ops/opset13/sub.go index 8e49337..cb6061d 100644 --- a/ops/opset13/sub.go +++ b/ops/opset13/sub.go @@ -7,11 +7,8 @@ import ( ) const ( - // MinSubInput is the minimimum amount of inputs the sub operator expects. - MinSubInput = 2 - - // MaxSubInput is the maximum amount of inputs the sub operator accepts. - MaxSubInput = 2 + MinSubInputs = 2 + MaxSubInputs = 2 ) // Sub represents the ONNX sub operator. @@ -49,12 +46,12 @@ func (s *Sub) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. 
func (s *Sub) GetMinInputs() int { - return MinSubInput + return MinSubInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Sub) GetMaxInputs() int { - return MaxSubInput + return MaxSubInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go index d1e1780..08090f4 100644 --- a/ops/opset13/transpose.go +++ b/ops/opset13/transpose.go @@ -6,7 +6,10 @@ import ( "gorgonia.org/tensor" ) -const TransposeInputs = 1 +const ( + MinTransposeInputs = 1 + MaxTransposeInputs = 1 +) // Transpose represents the ONNX transpose operator. type Transpose struct { @@ -55,12 +58,12 @@ func (t *Transpose) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. func (t *Transpose) GetMinInputs() int { - return TransposeInputs + return MinTransposeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (t *Transpose) GetMaxInputs() int { - return TransposeInputs + return MaxTransposeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index 8ba0f4e..6547541 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -8,8 +8,10 @@ import ( "gorgonia.org/tensor" ) -// NUnsqueezeInputs is the exact number of inputs the unsqueeze operator expects. -const NUnsqueezeInputs = 2 +const ( + MinUnsqueezeInputs = 2 + MaxUnsqueezeInputs = 2 +) // Unsqueeze represents the ONNX unsqueeze operator. type Unsqueeze struct{} @@ -68,12 +70,12 @@ func (u *Unsqueeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. 
func (u *Unsqueeze) GetMinInputs() int { - return NUnsqueezeInputs + return MinUnsqueezeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (u *Unsqueeze) GetMaxInputs() int { - return NUnsqueezeInputs + return MaxUnsqueezeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/utils.go b/ops/utils.go index 63296c6..98d205f 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -1,8 +1,6 @@ package ops import ( - "errors" - "gorgonia.org/tensor" ) @@ -80,11 +78,6 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { return nil } -var ( - ErrCast = errors.New("cast error") - ErrInvalidShape = errors.New("invalid shape error") -) - // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. func AnyToIntSlice(value interface{}) ([]int, error) {