diff --git a/Makefile b/Makefile index ec0e94b..f05e1b4 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test_ci: ## Run tests using normal test runner for ci output. test_data: ## Creates test data from the ONNX test module. rm -R ./test_data; mkdir ./test_data; touch ./test_data/ - git clone --depth 1 --branch v1.15.0 https://github.com/onnx/onnx.git temp_onnx + git clone --depth 1 --branch v1.17.0 https://github.com/onnx/onnx.git temp_onnx cp -r temp_onnx/onnx/backend/test/data/node/* ./test_data rm -Rf temp_onnx diff --git a/ops/constant/constant_1.go b/ops/constant/constant_1.go index 9fd18c0..964f6b6 100644 --- a/ops/constant/constant_1.go +++ b/ops/constant/constant_1.go @@ -26,7 +26,7 @@ func (c *Constant1) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err diff --git a/ops/constant/constant_11.go b/ops/constant/constant_11.go index b10b7ea..b5de777 100644 --- a/ops/constant/constant_11.go +++ b/ops/constant/constant_11.go @@ -27,9 +27,9 @@ func (c *Constant11) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "sparse_value": + case sparseValue: return ops.ErrUnsupportedAttribute(attr.GetName(), c) - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err diff --git a/ops/constant/constant_12.go b/ops/constant/constant_12.go index 0174bea..3390b49 100644 --- a/ops/constant/constant_12.go +++ b/ops/constant/constant_12.go @@ -27,23 +27,23 @@ func (c *Constant12) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "sparse_value", "value_string", "value_strings": + case sparseValue, valueString, valueStrings: return ops.ErrUnsupportedAttribute(attr.GetName(), c) - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err } c.value = t - case "value_float": + case valueFloat: c.value = tensor.New(tensor.FromScalar(attr.GetF())) - case "value_floats": + case valueFloats: floats := attr.GetFloats() c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) - case "value_int": + case valueInt: c.value = tensor.New(tensor.FromScalar(attr.GetI())) - case "value_ints": + case valueInts: ints := attr.GetInts() c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) default: diff --git a/ops/constant/constant_13.go b/ops/constant/constant_13.go index 85261bc..c75f629 100644 --- a/ops/constant/constant_13.go +++ b/ops/constant/constant_13.go @@ -27,23 +27,23 @@ func (c *Constant13) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "sparse_value", "value_string", "value_strings": + case sparseValue, valueString, valueStrings: return ops.ErrUnsupportedAttribute(attr.GetName(), c) - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err } c.value = t - case "value_float": + case valueFloat: c.value = tensor.New(tensor.FromScalar(attr.GetF())) - case "value_floats": + case valueFloats: floats := attr.GetFloats() c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) - case "value_int": + case valueInt: c.value = tensor.New(tensor.FromScalar(attr.GetI())) - case "value_ints": + case valueInts: ints := attr.GetInts() c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) default: diff --git a/ops/constant/constant_9.go b/ops/constant/constant_9.go index 08b6830..fe054dd 100644 --- a/ops/constant/constant_9.go +++ b/ops/constant/constant_9.go @@ -26,7 +26,7 @@ func (c *Constant9) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err diff --git a/ops/flatten/flatten_1.go b/ops/flatten/flatten_1.go index 42f20fa..c2c6b66 100644 --- a/ops/flatten/flatten_1.go +++ b/ops/flatten/flatten_1.go @@ -27,7 +27,7 @@ func newFlatten1() ops.Operator { func (f *Flatten1) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "axis": + case axis: f.axis = int(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), f) @@ -40,6 +40,7 @@ func (f *Flatten1) Init(n *onnx.NodeProto) error { // Apply applies the flatten operator. func (f *Flatten1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { inputShape := inputs[0].Shape() + out, ok := inputs[0].Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) diff --git a/ops/flatten/flatten_11.go b/ops/flatten/flatten_11.go index 55685e4..d483365 100644 --- a/ops/flatten/flatten_11.go +++ b/ops/flatten/flatten_11.go @@ -27,7 +27,7 @@ func newFlatten11() ops.Operator { func (f *Flatten11) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "axis": + case axis: f.axis = int(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), f) diff --git a/ops/flatten/flatten_13.go b/ops/flatten/flatten_13.go index 621f74c..190831d 100644 --- a/ops/flatten/flatten_13.go +++ b/ops/flatten/flatten_13.go @@ -27,7 +27,7 @@ func newFlatten13() ops.Operator { func (f *Flatten13) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "axis": + case axis: f.axis = int(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), f) diff --git a/ops/flatten/flatten_9.go b/ops/flatten/flatten_9.go index b7f5c45..1932954 100644 --- a/ops/flatten/flatten_9.go +++ b/ops/flatten/flatten_9.go @@ -27,7 +27,7 @@ func newFlatten9() ops.Operator { func (f *Flatten9) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "axis": + case axis: f.axis = int(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), f) @@ -40,6 +40,7 @@ func (f *Flatten9) Init(n *onnx.NodeProto) error { // Apply applies the flatten operator. func (f *Flatten9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { inputShape := inputs[0].Shape() + out, ok := inputs[0].Clone().(tensor.Tensor) if !ok { return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) diff --git a/ops/gather/gather_1.go b/ops/gather/gather_1.go index 01cdce4..0198f7d 100644 --- a/ops/gather/gather_1.go +++ b/ops/gather/gather_1.go @@ -30,7 +30,7 @@ func (g *Gather1) Init(n *onnx.NodeProto) error { if len(attributes) == 1 { attr := attributes[0] - if attr.GetName() == "axis" { + if attr.GetName() == axis { g.axis = int(attr.GetI()) } else { return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gather/gather_11.go b/ops/gather/gather_11.go index 3677d39..a3f8358 100644 --- a/ops/gather/gather_11.go +++ b/ops/gather/gather_11.go @@ -30,7 +30,7 @@ func (g *Gather11) Init(n *onnx.NodeProto) error { if len(attributes) == 1 { attr := attributes[0] - if attr.GetName() == "axis" { + if attr.GetName() == axis { g.axis = int(attr.GetI()) } else { return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gather/gather_13.go b/ops/gather/gather_13.go index f4bd7fc..2c72ea0 100644 --- a/ops/gather/gather_13.go +++ b/ops/gather/gather_13.go @@ -30,7 +30,7 @@ func (g *Gather13) Init(n *onnx.NodeProto) error { if len(attributes) == 1 { attr := attributes[0] - if attr.GetName() == "axis" { + if attr.GetName() == axis { g.axis = int(attr.GetI()) } else { return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gemm/gemm_11.go b/ops/gemm/gemm_11.go index 44c7638..ae9d71f 100644 --- a/ops/gemm/gemm_11.go +++ b/ops/gemm/gemm_11.go @@ -33,13 +33,13 @@ func newGemm11() ops.Operator { func (g *Gemm11) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "alpha": + case alpha: g.alpha = attr.GetF() - case "beta": + case beta: g.beta = attr.GetF() - case "transA": + case transA: g.transA = ops.Int64ToBool(attr.GetI()) - case "transB": + case transB: g.transB = ops.Int64ToBool(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gemm/gemm_13.go b/ops/gemm/gemm_13.go index b3c6de9..b67516b 100644 --- a/ops/gemm/gemm_13.go +++ b/ops/gemm/gemm_13.go @@ -33,13 +33,13 @@ func newGemm13() ops.Operator { func (g *Gemm13) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "alpha": + case alpha: g.alpha = attr.GetF() - case "beta": + case beta: g.beta = attr.GetF() - case "transA": + case transA: g.transA = ops.Int64ToBool(attr.GetI()) - case "transB": + case transB: g.transB = ops.Int64ToBool(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gemm/gemm_7.go b/ops/gemm/gemm_7.go index b7192f8..19a7cd2 100644 --- a/ops/gemm/gemm_7.go +++ b/ops/gemm/gemm_7.go @@ -33,13 +33,13 @@ func newGemm7() ops.Operator { func (g *Gemm7) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "alpha": + case alpha: g.alpha = attr.GetF() - case "beta": + case beta: g.beta = attr.GetF() - case "transA": + case transA: g.transA = ops.Int64ToBool(attr.GetI()) - case "transB": + case transB: g.transB = ops.Int64ToBool(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/gemm/gemm_9.go b/ops/gemm/gemm_9.go index b06fb1c..f7d38db 100644 --- a/ops/gemm/gemm_9.go +++ b/ops/gemm/gemm_9.go @@ -33,13 +33,13 @@ func newGemm9() ops.Operator { func (g *Gemm9) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "alpha": + case alpha: g.alpha = attr.GetF() - case "beta": + case beta: g.beta = attr.GetF() - case "transA": + case transA: g.transA = ops.Int64ToBool(attr.GetI()) - case "transB": + case transB: g.transB = ops.Int64ToBool(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), g) diff --git a/ops/matmul/matmul_1.go b/ops/matmul/matmul_1.go index d106b3b..420b99c 100644 --- a/ops/matmul/matmul_1.go +++ b/ops/matmul/matmul_1.go @@ -134,7 +134,7 @@ func (m *MatMul1) GetInputTypeConstraints() [][]tensor.Dtype { // String implements the stringer interface, and can be used to format errors or messages. func (m *MatMul1) String() string { - return "matmul13 operator" + return "matmul1 operator" } // broadcastTensors broadcasts both tensors for the matmul operator. It is almost identical diff --git a/ops/matmul/matmul_9.go b/ops/matmul/matmul_9.go index 7beac84..3673553 100644 --- a/ops/matmul/matmul_9.go +++ b/ops/matmul/matmul_9.go @@ -134,7 +134,7 @@ func (m *MatMul9) GetInputTypeConstraints() [][]tensor.Dtype { // String implements the stringer interface, and can be used to format errors or messages. func (m *MatMul9) String() string { - return "matmul13 operator" + return "matmul9 operator" } // broadcastTensors broadcasts both tensors for the matmul operator. It is almost identical diff --git a/ops/reducemax/reduce_max_1.go b/ops/reducemax/reduce_max_1.go index 8b5f901..2cd8863 100644 --- a/ops/reducemax/reduce_max_1.go +++ b/ops/reducemax/reduce_max_1.go @@ -34,14 +34,14 @@ func (r *ReduceMax1) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemax/reduce_max_11.go b/ops/reducemax/reduce_max_11.go index b40be9a..24c5300 100644 --- a/ops/reducemax/reduce_max_11.go +++ b/ops/reducemax/reduce_max_11.go @@ -34,14 +34,14 @@ func (r *ReduceMax11) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemax/reduce_max_12.go b/ops/reducemax/reduce_max_12.go index 8139cfa..1bbcd6a 100644 --- a/ops/reducemax/reduce_max_12.go +++ b/ops/reducemax/reduce_max_12.go @@ -34,14 +34,14 @@ func (r *ReduceMax12) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemax/reduce_max_13.go b/ops/reducemax/reduce_max_13.go index f474ba2..f0801fd 100644 --- a/ops/reducemax/reduce_max_13.go +++ b/ops/reducemax/reduce_max_13.go @@ -34,14 +34,14 @@ func (r *ReduceMax13) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemin/reduce_min_1.go b/ops/reducemin/reduce_min_1.go index 5704ae0..91cdb3f 100644 --- a/ops/reducemin/reduce_min_1.go +++ b/ops/reducemin/reduce_min_1.go @@ -34,14 +34,14 @@ func (r *ReduceMin1) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemin/reduce_min_11.go b/ops/reducemin/reduce_min_11.go index 8c58d5c..4ebddb2 100644 --- a/ops/reducemin/reduce_min_11.go +++ b/ops/reducemin/reduce_min_11.go @@ -34,14 +34,14 @@ func (r *ReduceMin11) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemin/reduce_min_12.go b/ops/reducemin/reduce_min_12.go index 93a6f21..dfb41db 100644 --- a/ops/reducemin/reduce_min_12.go +++ b/ops/reducemin/reduce_min_12.go @@ -34,14 +34,14 @@ func (r *ReduceMin12) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/reducemin/reduce_min_13.go b/ops/reducemin/reduce_min_13.go index 6ab63cd..be24e55 100644 --- a/ops/reducemin/reduce_min_13.go +++ b/ops/reducemin/reduce_min_13.go @@ -34,14 +34,14 @@ func (r *ReduceMin13) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": + case axes: axes, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } r.axes = axes - case "keepdims": + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) diff --git a/ops/slice/slice_1.go b/ops/slice/slice_1.go index 13720c6..45e8df2 100644 --- a/ops/slice/slice_1.go +++ b/ops/slice/slice_1.go @@ -7,8 +7,10 @@ import ( ) const ( - MinSlice1Inputs = 3 - MaxSlice1Inputs = 5 + MinSliceAttributes = 2 + MaxSliceAttributes = 3 + MinSlice1Inputs = 3 + MaxSlice1Inputs = 5 ) // Slice1 represents the ONNX slice operator. @@ -27,7 +29,7 @@ func newSlice1() ops.Operator { func (s *Slice1) Init(n *onnx.NodeProto) error { nAttrs := len(n.GetAttribute()) if nAttrs < 2 || nAttrs > 3 { - return ops.ErrInvalidOptionalAttributeCount(2, 3, nAttrs, s) + return ops.ErrInvalidOptionalAttributeCount(MinSliceAttributes, MaxSliceAttributes, nAttrs, s) } for _, attr := range n.GetAttribute() { @@ -109,7 +111,7 @@ func (s *Slice1) GetInputTypeConstraints() [][]tensor.Dtype { // String implements the stringer interface, and can be used to format errors or messages. func (s *Slice1) String() string { - return "slice11 operator" + return "slice1 operator" } // constructSlice constructs a list with tensor.Slice objects. The list is initializes with nils. diff --git a/ops/slice/slice_10.go b/ops/slice/slice_10.go index 30d41cf..48687cf 100644 --- a/ops/slice/slice_10.go +++ b/ops/slice/slice_10.go @@ -93,7 +93,7 @@ func (s *Slice10) GetInputTypeConstraints() [][]tensor.Dtype { // String implements the stringer interface, and can be used to format errors or messages. func (s *Slice10) String() string { - return "slice11 operator" + return "slice10 operator" } // constructSlice constructs a list with tensor.Slice objects. The list is initializes with nils. diff --git a/ops/squeeze/squeeze_1.go b/ops/squeeze/squeeze_1.go index 548594e..13faf2b 100644 --- a/ops/squeeze/squeeze_1.go +++ b/ops/squeeze/squeeze_1.go @@ -57,10 +57,7 @@ func (s *Squeeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { ops.OffsetArrayIfNegative(dimsToSqueeze, nDims) if len(s.axes) > 0 { - dimsToSqueeze, err = getDimsToSqueezeFromList(s.axes, nDims) - if err != nil { - return nil, err - } + dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) } newShape := getNewShape(currentShape, dimsToSqueeze) diff --git a/ops/squeeze/squeeze_11.go b/ops/squeeze/squeeze_11.go index 35ffc81..cda72bb 100644 --- a/ops/squeeze/squeeze_11.go +++ b/ops/squeeze/squeeze_11.go @@ -57,10 +57,7 @@ func (s *Squeeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { ops.OffsetArrayIfNegative(dimsToSqueeze, nDims) if len(s.axes) > 0 { - dimsToSqueeze, err = getDimsToSqueezeFromList(s.axes, nDims) - if err != nil { - return nil, err - } + dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) } newShape := getNewShape(currentShape, dimsToSqueeze) @@ -105,7 +102,7 @@ func (s *Squeeze11) String() string { // based on a list of ints. The list should contain dimensions/axes to squeeze. Negative dimensions // represent dimensions counting from the end of the shape, i.e. -2 repesents the second // last dimension. -func getDimsToSqueezeFromList(axes []int, nDims int) ([]int, error) { +func getDimsToSqueezeFromList(axes []int, nDims int) []int { dimsToSqueeze := make([]int, len(axes)) copy(dimsToSqueeze, axes) @@ -115,5 +112,5 @@ func getDimsToSqueezeFromList(axes []int, nDims int) ([]int, error) { } } - return dimsToSqueeze, nil + return dimsToSqueeze } diff --git a/ops/transpose/transpose_13.go b/ops/transpose/transpose_13.go index 1f76c7d..bda0183 100644 --- a/ops/transpose/transpose_13.go +++ b/ops/transpose/transpose_13.go @@ -25,19 +25,17 @@ func newTranspose13() ops.Operator { func (t *Transpose13) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() - if len(attributes) != 1 { - return ops.ErrInvalidAttributeCount(1, len(attributes), t) - } - - attr := attributes[0] + if len(attributes) == 1 { + attr := attributes[0] - if attr.GetName() != "perm" { - return ops.ErrInvalidAttribute(attr.GetName(), t) - } + if attr.GetName() != "perm" { + return ops.ErrInvalidAttribute(attr.GetName(), t) + } - attrPerm := attr.GetInts() - for _, val := range attrPerm { - t.perm = append(t.perm, int(val)) + attrPerm := attr.GetInts() + for _, val := range attrPerm { + t.perm = append(t.perm, int(val)) + } } return nil diff --git a/ops/transpose/transpose_13_test.go b/ops/transpose/transpose_13_test.go index c61398e..7163178 100644 --- a/ops/transpose/transpose_13_test.go +++ b/ops/transpose/transpose_13_test.go @@ -25,14 +25,6 @@ func TestTranspose13InitFailWrongAttribute(t *testing.T) { assert.Equal(t, expected, err) } -func TestTranspose13InitFailAttrCount(t *testing.T) { - trans := &Transpose13{} - err := trans.Init(ops.EmptyNodeProto()) - - expected := ops.ErrInvalidAttributeCount(1, 0, trans) - assert.Equal(t, expected, err) -} - func TestTranspose13(t *testing.T) { tests := []struct { trans *Transpose13 diff --git a/ops_test.go b/ops_test.go index 7d017f5..4ea05dd 100644 --- a/ops_test.go +++ b/ops_test.go @@ -35,6 +35,7 @@ var ignoredTests = []string{ "test_logsoftmax_axis_2_expanded_ver18", // Opset18 "test_lstm_batchwise", // Opset14 "test_mul_uint8", // Opset14 + "test_reduce_max_empty_set", // Opset20 "test_reduce_max_do_not_keepdims_random", // Opset18 "test_reduce_max_keepdims_random", // Opset18 "test_reduce_max_default_axes_keepdims_random", // Opset18 @@ -75,9 +76,6 @@ var ignoredTests = []string{ "test_constant_pad", // Pad is not implemented yet. "test_constant_pad_axes", // Pad is not implemented yet. - "test_gemm_alpha", // For gemm in opset 11. - "test_gemm_default_no_bias", // For gemm in opset 11. - "test_gemm_default_scalar_bias", // For gemm in opset 11. "test_logsoftmax_large_number_expanded", // Requires 'Exp' operator. "test_logsoftmax_axis_0_expanded", // Requires 'Exp' operator. "test_logsoftmax_axis_1_expanded", // Requires 'Exp' operator. @@ -98,7 +96,6 @@ var ignoredTests = []string{ "test_slice_end_out_of_bounds", // ONNX expects nil output, but we throw an error. "test_slice_neg_steps", // ONNX expects nil output, but we throw an error. "test_slice_neg", // ONNX expects nil output, but we throw an error. - "test_transpose_default", // For transpose in opset 9. "test_equal_string", // Unsupported datatype String. "test_equal_string_broadcast", // Unsupported datatype String. @@ -145,7 +142,6 @@ var ignoredTests = []string{ "test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN", // Unsupported datatype. "test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2", // Unsupported datatype. - "test_unsqueeze_axis_3", // Tests an old version of Unsqueeze (<= 11) "test_constantofshape_int_shape_zero", // Empty tensors are not supported in gorgonia "test_gather_elements_0", // Operator GatherElements is not implemented "test_gather_elements_1", // Operator GatherElements is not implemented @@ -303,9 +299,14 @@ func readTestModel(folder string) (*Model, error) { return nil, err } - // Currently we only implemented Opset13, hence we enforce this in our tests. All + // Currently we support Opset 7-13, hence we enforce this in our tests. All // tests that fail because of this are ignored. - mp.OpsetImport[0].Version = 13 + fmt.Println(folder, mp.OpsetImport[0].Version) + if mp.OpsetImport[0].Version < MinSupportedOpset { + mp.OpsetImport[0].Version = MinSupportedOpset + } else if mp.OpsetImport[0].Version > MaxSupportedOpset { + mp.OpsetImport[0].Version = MaxSupportedOpset + } model, err := NewModel(mp) if err != nil { @@ -424,7 +425,10 @@ var expectedTests = []string{ "test_gather_negative_indices", "test_gemm_default_single_elem_vector_bias", "test_gemm_all_attributes", + "test_gemm_alpha", "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + "test_gemm_default_scalar_bias", "test_gemm_default_vector_bias", "test_gemm_transposeA", "test_gemm_default_zero_bias", @@ -514,6 +518,7 @@ var expectedTests = []string{ "test_transpose_all_permutations_3", "test_transpose_all_permutations_4", "test_transpose_all_permutations_5", + "test_transpose_default", "test_unsqueeze_axis_0", "test_unsqueeze_axis_1", "test_unsqueeze_axis_2", diff --git a/opset.go b/opset.go index 156117e..b307965 100644 --- a/opset.go +++ b/opset.go @@ -133,7 +133,7 @@ var operators = map[string]ops.OperatorVersions{ // one is used. If the opset version is 13, and an operator has versions 7 and 14, version 7 is used, as // it is the closest opset version going downwards. func GetClosestOperatorVersion(opsetID int64, versions ops.OperatorVersions) func() ops.Operator { - for closestOpset := opsetID; opsetID >= MinSupportedOpset; closestOpset-- { + for closestOpset := opsetID; closestOpset >= 1; closestOpset-- { if operator, ok := versions[closestOpset]; ok { return operator } @@ -149,6 +149,7 @@ func ResolveOpset(opsetID int64) (Opset, error) { } opset := map[string]func() ops.Operator{} + for operatorName, operatorVersions := range operators { operator := GetClosestOperatorVersion(opsetID, operatorVersions) if operator == nil {