Skip to content

Commit

Permalink
Merged develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Nov 11, 2023
2 parents efeaa1a + 75cee12 commit 7968eef
Show file tree
Hide file tree
Showing 73 changed files with 1,084 additions and 542 deletions.
15 changes: 11 additions & 4 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ linters:
linters-settings:
gomnd:
ignored-functions:
- 'strconv.ParseInt'
- 'strconv.ParseFloat'
- 'strconv.FormatInt'
- 'strconv.FormatFloat'
- "strconv.ParseInt"
- "strconv.ParseFloat"
- "strconv.FormatInt"
- "strconv.FormatFloat"
gocritic:
disabled-checks:
# In the world of AI tensor's are often denoted with a capital letter.
# We want to adopt the go style guide as much as possible but we also want
# to be able to easily show when a variable is a Tensor. So we chose to
# disable captLocal. Note that any other parameter should use a lower case letters.
- "captLocal"
issues:
max-issues-per-linter: 0
max-same-issues: 0
33 changes: 29 additions & 4 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
package gonnx

// InvalidShapeError is used when the shape of an input tensor does not match the expectation.
const InvalidShapeError = "input shape does not match for %v: expected %v but got %v"
import (
"errors"
"fmt"

// SetOutputTensorsError is used when the output of an operation could not be set.
const SetOutputTensorsError = "could not set output tensors, expected %v tensors but got %v"
"github.com/advancedclimatesystems/gonnx/onnx"
)

var errModel = errors.New("gonnx model error")

type InvalidShapeError struct {
expected onnx.Shape
actual []int
}

func (i InvalidShapeError) Error() string {
return fmt.Sprintf("invalid shape error expected: %v actual %v", i.expected, i.actual)
}

func ErrInvalidShape(expected onnx.Shape, actual []int) error {
return InvalidShapeError{
expected: expected,
actual: actual,
}
}

// ErrModel is used for when an error ocured during setup of running onnx models.
// The user can specify a formatted message using the standard formatting rules.
func ErrModel(format string, a ...any) error {
return fmt.Errorf("%w: %s", errModel, fmt.Sprintf(format, a...))
}
35 changes: 17 additions & 18 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ package gonnx

import (
"archive/zip"
"fmt"
"io/ioutil"
"io"
"os"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"google.golang.org/protobuf/proto"
"gorgonia.org/tensor"
)

// Tensors is a map with tensors
// Tensors is a map with tensors.
type Tensors map[string]tensor.Tensor

// Model defines a model that can be used for inference.
Expand All @@ -23,7 +23,7 @@ type Model struct {

// NewModelFromFile creates a new model from a path to a file.
func NewModelFromFile(path string) (*Model, error) {
bytesModel, err := ioutil.ReadFile(path)
bytesModel, err := os.ReadFile(path)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +38,7 @@ func NewModelFromZipFile(file *zip.File) (*Model, error) {
return nil, err
}

bytesModel, err := ioutil.ReadAll(fc)
bytesModel, err := io.ReadAll(fc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -66,6 +66,7 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
opsetImports := mp.GetOpsetImport()

var opsetID int64

for i := 0; i < len(opsetImports); i++ {
version := opsetImports[i].GetVersion()
if version > opsetID {
Expand All @@ -78,12 +79,11 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
return nil, err
}

model := &Model{
return &Model{
mp: mp,
parameters: params,
GetOperator: GetOperator,
}
return model, nil
}, nil
}

// ModelProtoFromBytes creates an onnx.ModelProto based on a list of bytes.
Expand All @@ -92,6 +92,7 @@ func ModelProtoFromBytes(bytesModel []byte) (*onnx.ModelProto, error) {
if err := proto.Unmarshal(bytesModel, mp); err != nil {
return nil, err
}

return mp, nil
}

Expand All @@ -108,16 +109,13 @@ func (m *Model) InputShapes() onnx.Shapes {
// InputDimSize returns the size of the input dimension given an input tensor.
func (m *Model) InputDimSize(input string, i int) (int, error) {
if !m.hasInput(input) {
return 0, fmt.Errorf("input %v does not exist", input)
return 0, ErrModel("input %v does not exist", input)
}

inputShape := m.mp.Graph.InputShapes()[input]

if i >= len(inputShape) {
err := fmt.Errorf(
"input %v only has %d dimensions, but index %d was required", input, len(inputShape), i,
)
return 0, err
return 0, ErrModel("input %v only has %d dimensions, but index %d was required", input, len(inputShape), i)
}

return int(inputShape[i].Size), nil
Expand Down Expand Up @@ -222,13 +220,13 @@ func (m *Model) validateShapes(inputTensors Tensors) error {

tensor, ok := inputTensors[name]
if !ok {
return fmt.Errorf("tensor: %v not found", name)
return ErrModel("tensor: %v not found", name)
}

shapeReceived := tensor.Shape()

if len(shapeReceived) != len(shapeExpected) {
return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived)
return ErrInvalidShape(shapeExpected, shapeReceived)
}

for i, dim := range shapeExpected {
Expand All @@ -239,7 +237,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error {
}

if dim.Size != int64(shapeReceived[i]) {
return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived)
return ErrInvalidShape(shapeExpected, shapeReceived)
}
}
}
Expand All @@ -249,6 +247,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error {

func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, error) {
var inputTensors []tensor.Tensor

for _, tensorName := range names {
// An empty name can happen in between optional inputs, like:
// [<required_input>, <optional_input>, nil, <optional_input>]
Expand All @@ -259,7 +258,7 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e
} else if tensor, ok := tensors[tensorName]; ok {
inputTensors = append(inputTensors, tensor)
} else {
return nil, fmt.Errorf("no tensor yet for name %v", tensorName)
return nil, ErrModel("no tensor yet for name %v", tensorName)
}
}

Expand All @@ -270,7 +269,7 @@ func setOutputTensorsOfNode(
names []string, outputTensors []tensor.Tensor, tensors Tensors,
) error {
if len(names) != len(outputTensors) {
return fmt.Errorf(SetOutputTensorsError, len(names), len(outputTensors))
return ErrModel("could not set output tensor")
}

for i, tensor := range outputTensors {
Expand Down
19 changes: 11 additions & 8 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package gonnx

import (
"errors"
"fmt"
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
Expand Down Expand Up @@ -39,9 +37,7 @@ func TestModel(t *testing.T) {
[][]float32{rangeFloat(16)},
),
nil,
errors.New(
"input shape does not match for data_input: expected [0 3] but got (2, 4, 2)",
),
ErrInvalidShape([]onnx.Dim{{IsDynamic: true, Name: "batch_size", Size: 0}, {IsDynamic: false, Name: "", Size: 3}}, []int{2, 4, 2}),
},
{
"./sample_models/onnx_models/mlp.onnx",
Expand All @@ -51,7 +47,7 @@ func TestModel(t *testing.T) {
[][]float32{rangeFloat(6)},
),
nil,
errors.New("tensor: data_input not found"),
ErrModel("tensor: %v not found", "data_input"),
},
{
"./sample_models/onnx_models/gru.onnx",
Expand Down Expand Up @@ -106,6 +102,7 @@ func TestModel(t *testing.T) {
outputs, err := model.Run(test.input)

assert.Equal(t, test.err, err)

if test.expected == nil {
assert.Nil(t, outputs)
} else {
Expand All @@ -128,6 +125,7 @@ func TestModelIOUtil(t *testing.T) {
{IsDynamic: false, Name: "", Size: 3},
},
}

assert.Equal(t, []string{"data_input"}, model.InputNames())
assert.Equal(t, expectedInputShapes, model.InputShapes())

Expand All @@ -137,6 +135,7 @@ func TestModelIOUtil(t *testing.T) {
{IsDynamic: false, Name: "", Size: 2},
},
}

assert.Equal(t, []string{"preds"}, model.OutputNames())
assert.Equal(t, expectedOutputShapes, model.OutputShapes())
assert.Equal(t, expectedOutputShapes["preds"], model.OutputShape("preds"))
Expand Down Expand Up @@ -165,11 +164,12 @@ func TestInputDimSizeInvalidInput(t *testing.T) {
assert.Nil(t, err)

_, err = model.InputDimSize("swagger", 0)
assert.Equal(t, fmt.Errorf("input swagger does not exist"), err)

assert.Equal(t, ErrModel("input %v does not exist", "swagger"), err)
}

// tensorsFixture creates Tensors with the given names shapes and backings. This is useful for
// providing a model with inputs and checking it's outputs
// providing a model with inputs and checking it's outputs.
func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors {
res := make(Tensors, len(names))
for i, name := range names {
Expand All @@ -178,6 +178,7 @@ func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors
tensor.WithBacking(backing[i]),
)
}

return res
}

Expand All @@ -186,6 +187,7 @@ func rangeFloat(size int) []float32 {
for i := 0; i < size; i++ {
res[i] = float32(i)
}

return res
}

Expand All @@ -194,6 +196,7 @@ func rangeZeros(size int) []float32 {
for i := range res {
res[i] = 0.0
}

return res
}

Expand Down
Loading

0 comments on commit 7968eef

Please sign in to comment.