diff --git a/ops/constant/constant_1.go b/ops/constant/constant_1.go new file mode 100644 index 0000000..e8f4cea --- /dev/null +++ b/ops/constant/constant_1.go @@ -0,0 +1,72 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant1 represents the ONNX constant operator. +type Constant1 struct { + value tensor.Tensor +} + +// newConstant1 creates a new constant operator. +func NewConstant1() ops.Operator { + return &Constant1{} +} + +// Init initializes the constant operator. +func (c *Constant1) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case "value": + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. +func (c *Constant1) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Constant1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Constant1) GetMinInputs() int { + return 0 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Constant1) GetMaxInputs() int { + return 0 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. 
+func (c *Constant1) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Constant1) String() string { + return "constant1 operator" +} diff --git a/ops/constant/constant_11.go b/ops/constant/constant_11.go new file mode 100644 index 0000000..557df5e --- /dev/null +++ b/ops/constant/constant_11.go @@ -0,0 +1,75 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant11 represents the ONNX constant operator. +type Constant11 struct { + value tensor.Tensor +} + +// newConstant11 creates a new constant operator. +func NewConstant11() ops.Operator { + return &Constant11{} +} + +// Init initializes the constant operator. It supports all constant types except +// `sparse_value`. +func (c *Constant11) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case "sparse_value": + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + case "value": + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. +func (c *Constant11) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Constant11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. 
+func (c *Constant11) GetMinInputs() int { + return 0 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Constant11) GetMaxInputs() int { + return 0 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Constant11) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Constant11) String() string { + return "constant11 operator" +} diff --git a/ops/opset13/constant.go b/ops/constant/constant_12.go similarity index 74% rename from ops/opset13/constant.go rename to ops/constant/constant_12.go index d0c1261..442a4f4 100644 --- a/ops/opset13/constant.go +++ b/ops/constant/constant_12.go @@ -1,4 +1,4 @@ -package opset13 +package constant import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,19 +6,19 @@ import ( "gorgonia.org/tensor" ) -// Constant represents the ONNX constant operator. -type Constant struct { +// Constant12 represents the ONNX constant operator. +type Constant12 struct { value tensor.Tensor } -// newConstant creates a new constant operator. -func newConstant() ops.Operator { - return &Constant{} +// newConstant12 creates a new constant operator. +func NewConstant12() ops.Operator { + return &Constant12{} } // Init initializes the constant operator. It supports all constant types except // `sparse_value`, `value_string`, and `value_strings`. -func (c *Constant) Init(n *onnx.NodeProto) error { +func (c *Constant12) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() if len(attributes) != 1 { return ops.ErrInvalidAttributeCount(1, len(attributes), c) @@ -54,32 +54,32 @@ func (c *Constant) Init(n *onnx.NodeProto) error { } // Apply applies the constant operator. 
-func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Constant12) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{c.value}, nil } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Constant) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Constant12) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Constant) GetMinInputs() int { +func (c *Constant12) GetMinInputs() int { return 0 } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Constant) GetMaxInputs() int { +func (c *Constant12) GetMaxInputs() int { return 0 } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (c *Constant) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *Constant12) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{} } // String implements the stringer interface, and can be used to format errors or messages. -func (c *Constant) String() string { - return "constant operator" +func (c *Constant12) String() string { + return "constant12 operator" } diff --git a/ops/constant/constant_13.go b/ops/constant/constant_13.go new file mode 100644 index 0000000..994d2ee --- /dev/null +++ b/ops/constant/constant_13.go @@ -0,0 +1,85 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant13 represents the ONNX constant operator. +type Constant13 struct { + value tensor.Tensor +} + +// newConstant13 creates a new constant operator. +func NewConstant13() ops.Operator { + return &Constant13{} +} + +// Init initializes the constant operator. 
It supports all constant types except +// `sparse_value`, `value_string`, and `value_strings`. +func (c *Constant13) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case "sparse_value", "value_string", "value_strings": + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + case "value": + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + case "value_float": + c.value = tensor.New(tensor.FromScalar(attr.GetF())) + case "value_floats": + floats := attr.GetFloats() + c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) + case "value_int": + c.value = tensor.New(tensor.FromScalar(attr.GetI())) + case "value_ints": + ints := attr.GetInts() + c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. +func (c *Constant13) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Constant13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Constant13) GetMinInputs() int { + return 0 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Constant13) GetMaxInputs() int { + return 0 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. 
+func (c *Constant13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Constant13) String() string { + return "constant13 operator" +} diff --git a/ops/opset13/constant_test.go b/ops/constant/constant_13_test.go similarity index 66% rename from ops/opset13/constant_test.go rename to ops/constant/constant_13_test.go index ffebccf..b56c769 100644 --- a/ops/opset13/constant_test.go +++ b/ops/constant/constant_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package constant import ( "encoding/binary" @@ -10,56 +10,56 @@ import ( "gorgonia.org/tensor" ) -func TestConstantInit(t *testing.T) { +func TestConstant13Init(t *testing.T) { tests := []struct { initAttr []*onnx.AttributeProto expected interface{} err error }{ { - ConstantValueAttrProtoFixture(), + Constant13ValueAttrProtoFixture(), tensor.New(tensor.WithBacking([]int64{1, 1, 1})), nil, }, { - ConstantValueFloatAttrProtoFixture(), + Constant13ValueFloatAttrProtoFixture(), tensor.New(tensor.FromScalar(float32(0.2))), nil, }, { - ConstantValueFloatsAttrProtoFixture(), + Constant13ValueFloatsAttrProtoFixture(), tensor.New(tensor.WithBacking([]float32{0.1, 0.2})), nil, }, { - ConstantValueIntAttrProtoFixture(), + Constant13ValueIntAttrProtoFixture(), tensor.New(tensor.FromScalar(int64(1))), nil, }, { - ConstantValueIntsAttrProtoFixture(), + Constant13ValueIntsAttrProtoFixture(), tensor.New(tensor.WithBacking([]int64{1, 2, 3})), nil, }, { []*onnx.AttributeProto{{Name: "sparse_value"}}, nil, - ops.ErrUnsupportedAttribute("sparse_value", &Constant{}), + ops.ErrUnsupportedAttribute("sparse_value", &Constant13{}), }, { []*onnx.AttributeProto{{Name: "unknownAttribute"}}, nil, - ops.ErrUnsupportedAttribute("unknownAttribute", &Constant{}), + ops.ErrUnsupportedAttribute("unknownAttribute", &Constant13{}), }, { []*onnx.AttributeProto{}, nil, - ops.ErrInvalidAttributeCount(1, 0, &Constant{}), + 
ops.ErrInvalidAttributeCount(1, 0, &Constant13{}), }, } for _, test := range tests { - constant := &Constant{} + constant := &Constant13{} err := constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) assert.Equal(t, test.err, err) @@ -70,35 +70,35 @@ func TestConstantInit(t *testing.T) { } } -func TestConstant(t *testing.T) { +func TestConstant13(t *testing.T) { tests := []struct { - constant *Constant + constant *Constant13 initAttr []*onnx.AttributeProto expected interface{} }{ { - &Constant{}, - ConstantValueAttrProtoFixture(), + &Constant13{}, + Constant13ValueAttrProtoFixture(), []int64{1, 1, 1}, }, { - &Constant{}, - ConstantValueFloatAttrProtoFixture(), + &Constant13{}, + Constant13ValueFloatAttrProtoFixture(), float32(0.2), }, { - &Constant{}, - ConstantValueFloatsAttrProtoFixture(), + &Constant13{}, + Constant13ValueFloatsAttrProtoFixture(), []float32{0.1, 0.2}, }, { - &Constant{}, - ConstantValueIntAttrProtoFixture(), + &Constant13{}, + Constant13ValueIntAttrProtoFixture(), int64(1), }, { - &Constant{}, - ConstantValueIntsAttrProtoFixture(), + &Constant13{}, + Constant13ValueIntsAttrProtoFixture(), []int64{1, 2, 3}, }, } @@ -112,15 +112,15 @@ func TestConstant(t *testing.T) { } } -func TestConstantSingleIntShapeTensor(t *testing.T) { - constant := &Constant{} +func TestConstant13SingleIntShapeTensor(t *testing.T) { + constant := &Constant13{} err := constant.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{2}}}}) assert.Nil(t, err) assert.False(t, constant.value.IsScalar()) } -func TestInputValidationConstant(t *testing.T) { +func TestInputValidationConstant13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -133,12 +133,12 @@ func TestInputValidationConstant(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Constant{}), + ops.ErrInvalidInputCount(1, &Constant13{}), }, } for _, test := range tests { - constant := &Constant{} + 
constant := &Constant13{} validated, err := constant.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -149,7 +149,7 @@ func TestInputValidationConstant(t *testing.T) { } } -func ConstantValueAttrProtoFixture() []*onnx.AttributeProto { +func Constant13ValueAttrProtoFixture() []*onnx.AttributeProto { values := []int64{1, 1, 1} bValues := make([]byte, 24) @@ -162,18 +162,18 @@ func ConstantValueAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value", T: tp}} } -func ConstantValueFloatAttrProtoFixture() []*onnx.AttributeProto { +func Constant13ValueFloatAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value_float", F: float32(0.2)}} } -func ConstantValueFloatsAttrProtoFixture() []*onnx.AttributeProto { +func Constant13ValueFloatsAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value_floats", Floats: []float32{0.1, 0.2}}} } -func ConstantValueIntAttrProtoFixture() []*onnx.AttributeProto { +func Constant13ValueIntAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value_int", I: int64(1)}} } -func ConstantValueIntsAttrProtoFixture() []*onnx.AttributeProto { +func Constant13ValueIntsAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{1, 2, 3}}} } diff --git a/ops/constant/constant_9.go b/ops/constant/constant_9.go new file mode 100644 index 0000000..43be99f --- /dev/null +++ b/ops/constant/constant_9.go @@ -0,0 +1,72 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant9 represents the ONNX constant operator. +type Constant9 struct { + value tensor.Tensor +} + +// newConstant9 creates a new constant operator. +func NewConstant9() ops.Operator { + return &Constant9{} +} + +// Init initializes the constant operator. 
+func (c *Constant9) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case "value": + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. +func (c *Constant9) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Constant9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Constant9) GetMinInputs() int { + return 0 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Constant9) GetMaxInputs() int { + return 0 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Constant9) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{} +} + +// String implements the stringer interface, and can be used to format errors or messages. 
+func (c *Constant9) String() string { + return "constant9 operator" +} diff --git a/ops/opset13/constant_of_shape.go b/ops/constantofshape/constant_of_shape_9.go similarity index 70% rename from ops/opset13/constant_of_shape.go rename to ops/constantofshape/constant_of_shape_9.go index 9511108..6aee059 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/constantofshape/constant_of_shape_9.go @@ -1,4 +1,4 @@ -package opset13 +package constantofshape import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,24 +7,24 @@ import ( ) const ( - MinConstantOfShapeInputs = 1 - MaxConstantOfShapeInputs = 1 + MinConstantOfShape9Inputs = 1 + MaxConstantOfShape9Inputs = 1 ) -// ConstantOfShape represents the ONNX constant of shape operator. -type ConstantOfShape struct { +// ConstantOfShape9 represents the ONNX constant of shape operator. +type ConstantOfShape9 struct { // One element tensor, giving the value and type of the output tensor // defaults to value 0 and type float32. value *tensor.Dense } -// newConstantOfShape creates a new constant of shape operator. -func newConstantOfShape() ops.Operator { - return &ConstantOfShape{} +// newConstantOfShape9 creates a new constant of shape operator. +func NewConstantOfShape9() ops.Operator { + return &ConstantOfShape9{} } // Init initializes the constant of shape operator. -func (c *ConstantOfShape) Init(n *onnx.NodeProto) error { +func (c *ConstantOfShape9) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() if len(attributes) > 1 { @@ -54,7 +54,7 @@ func (c *ConstantOfShape) Init(n *onnx.NodeProto) error { } // Apply applies the constant of shape operator. 
-func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *ConstantOfShape9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { shape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[0].Data())) if err != nil { return nil, err @@ -78,29 +78,29 @@ func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *ConstantOfShape9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *ConstantOfShape) GetMinInputs() int { - return MinConstantOfShapeInputs +func (c *ConstantOfShape9) GetMinInputs() int { + return MinConstantOfShape9Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *ConstantOfShape) GetMaxInputs() int { - return MaxConstantOfShapeInputs +func (c *ConstantOfShape9) GetMaxInputs() int { + return MaxConstantOfShape9Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (c *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *ConstantOfShape9) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Int64}, } } // String implements the stringer interface, and can be used to format errors or messages. 
-func (c *ConstantOfShape) String() string { - return "constant of shape operator" +func (c *ConstantOfShape9) String() string { + return "constantofshape9 operator" } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/constantofshape/constant_of_shape_9_test.go similarity index 90% rename from ops/opset13/constant_of_shape_test.go rename to ops/constantofshape/constant_of_shape_9_test.go index e294c25..31318c4 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/constantofshape/constant_of_shape_9_test.go @@ -1,4 +1,4 @@ -package opset13 +package constantofshape import ( "encoding/binary" @@ -64,7 +64,7 @@ func TensorProtoFromNumber(n interface{}) *onnx.TensorProto { } } -func TestConstantOfShape(t *testing.T) { +func TestConstantOfShape9(t *testing.T) { // Test cases, verifying that all these types work. // Unfortunately uint* and bool are not supported. tests := []struct { @@ -90,7 +90,7 @@ func TestConstantOfShape(t *testing.T) { node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} // Create operator - op := ConstantOfShape{} + op := ConstantOfShape9{} err := op.Init(node) assert.NoError(t, err) assert.Equal(t, test.input, op.value.Data()) @@ -106,8 +106,8 @@ func TestConstantOfShape(t *testing.T) { } } -func TestConstantOfShapeEmptyInit(t *testing.T) { - op := &ConstantOfShape{} +func TestConstantOfShape9EmptyInit(t *testing.T) { + op := &ConstantOfShape9{} // No init value given err := op.Init(ops.EmptyNodeProto()) @@ -132,7 +132,7 @@ func TestIncorrectInput(t *testing.T) { } node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} - op := &ConstantOfShape{} + op := &ConstantOfShape9{} err := op.Init(node) assert.NotNil(t, err) assert.Equal( @@ -143,7 +143,7 @@ func TestIncorrectInput(t *testing.T) { } func TestNegativeShapeNotAllowed(t *testing.T) { - op := &ConstantOfShape{} + op := &ConstantOfShape9{} _ = op.Init(ops.EmptyNodeProto()) shape := []int64{1, -1} @@ -159,7 +159,7 @@ func 
TestNegativeShapeNotAllowed(t *testing.T) { } func TestEmptyTensorNotAllowed(t *testing.T) { - op := &ConstantOfShape{} + op := &ConstantOfShape9{} _ = op.Init(ops.EmptyNodeProto()) shape := []int64{0} @@ -175,7 +175,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) { } func TestScalarShapeInput(t *testing.T) { - op := &ConstantOfShape{} + op := &ConstantOfShape9{} _ = op.Init(ops.EmptyNodeProto()) shape := []int64{6} @@ -187,7 +187,7 @@ func TestScalarShapeInput(t *testing.T) { assert.Equal(t, []float32{0, 0, 0, 0, 0, 0}, res[0].Data()) } -func TestInputValidationConstantOfShape(t *testing.T) { +func TestInputValidationConstantOfShape9(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -200,16 +200,16 @@ func TestInputValidationConstantOfShape(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &ConstantOfShape{}), + ops.ErrInvalidInputCount(0, &ConstantOfShape9{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &ConstantOfShape{}), + ops.ErrInvalidInputType(0, "int", &ConstantOfShape9{}), }, } for _, test := range tests { - constantOfShape := &ConstantOfShape{} + constantOfShape := &ConstantOfShape9{} validated, err := constantOfShape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/conv/conv_1.go b/ops/conv/conv_1.go new file mode 100644 index 0000000..fe7bef4 --- /dev/null +++ b/ops/conv/conv_1.go @@ -0,0 +1,575 @@ +package conv + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinConv1Inputs = 2 + MaxConv1Inputs = 3 + NDims1DConv1olution = 3 + NDims2DConv1olution = 4 +) + +// Conv1 represents the ONNX conv operator. +type Conv1 struct { + autoPad AutoPadSetting + dilations []int + group int + kernelShape []int + pads []int + strides []int +} + +// newConv1 creates a new conv operator. 
+func NewConv1() ops.Operator { + return &Conv1{ + autoPad: NotSet, + } +} + +// Init initializes the conv operator. +func (c *Conv1) Init(n *onnx.NodeProto) error { + var err error + + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "auto_pad": + c.autoPad = AutoPadSetting(attr.GetS()) + case "dilations": + c.dilations, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "group": + c.group = int(attr.GetI()) + if c.group != 1 { + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + case "kernel_shape": + c.kernelShape, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "pads": + c.pads, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "strides": + c.strides, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + } + + return nil +} + +// Apply applies the conv operator. +func (c *Conv1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + x := inputs[0] + kernel := inputs[1] + bias := inputs[2] + + if len(c.dilations) == 0 { + c.setDefaultDilations(x) + } + + if len(c.kernelShape) == 0 { + c.setKernelShape(kernel) + } + + if len(c.pads) == 0 { + c.setDefaultPaddings(x) + } + + if len(c.strides) == 0 { + c.setDefaultStrides(x) + } + + kernel, err := c.getDilatedKernel(kernel) + if err != nil { + return nil, err + } + + if c.autoPad != NotSet { + c.setPaddingWithAutoPad(x) + } + + var out tensor.Tensor + + switch len(x.Shape()) { + case NDims1DConv1olution: + out, err = c.applyConv11D(x, kernel) + case NDims2DConv1olution: + out, err = c.applyConv12D(x, kernel) + default: + return nil, ops.ErrInvalidInput("the convolution operator currently only supports 1D or 2D convolution, i.e. 
shape [N x C x H (x W)]", c) + } + + if err != nil { + return nil, err + } + + if bias != nil { + out, err = c.addBias(out, bias) + if err != nil { + return nil, err + } + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Conv1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Conv1) GetMinInputs() int { + return MinConv1Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Conv1) GetMaxInputs() int { + return MaxConv1Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Conv1) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Conv1) String() string { + return "conv1 operator" +} + +// setDefaultDilations sets the dilations attribute to the default. Can be called when no +// dilations were set when initializing. +func (c *Conv1) setDefaultDilations(x tensor.Tensor) { + nDims := len(x.Shape()[2:]) + + dilations := make([]int, nDims) + for i := 0; i < nDims; i++ { + dilations[i] = 1 + } + + c.dilations = dilations +} + +// setKernelShape infers the shape of the kernel when it was not given in the attributes. +func (c *Conv1) setKernelShape(kernel tensor.Tensor) { + c.kernelShape = kernel.Shape()[2:] +} + +// setDefaultPaddings sets default paddings as attribute. Can be called when no paddings +// were set during initialization. 
+func (c *Conv1) setDefaultPaddings(x tensor.Tensor) { + NPadsPerDim := 2 + paddingLength := len(x.Shape()[2:]) * NPadsPerDim + + pads := make([]int, paddingLength) + for i := 0; i < paddingLength; i++ { + pads[i] = 0 + } + + c.pads = pads +} + +// setDefaultStrides sets default strides as attribute. Can be called when no strides +// were set during initialization. +func (c *Conv1) setDefaultStrides(x tensor.Tensor) { + nDims := len(x.Shape()[2:]) + + strides := make([]int, nDims) + for i := 0; i < nDims; i++ { + strides[i] = 1 + } + + c.strides = strides +} + +// setPaddingWithAutoPad sets the padding attribute of the operator based on +// the input tensor `x`, the shape of the kernel and the strides. +func (c *Conv1) setPaddingWithAutoPad(x tensor.Tensor) { + if c.autoPad == NotSet { + return + } + + NPadsPerDim := 2 + inputShape := x.Shape() + nDims := len(inputShape) + nSpatialDims := nDims - nNonSpatialDims + + c.pads = make([]int, nSpatialDims*NPadsPerDim) + + for i := 0; i < nSpatialDims; i++ { + dim := inputShape[i] + targetSize := (dim + c.strides[i] - 1) / c.strides[i] + padNeeded := (targetSize-1)*c.strides[i] + c.kernelShape[i] - dim + + var padHead int + if c.autoPad == SameLower { + // nolint as the division by zero is literally division by two + padHead = (padNeeded + 1) / 2 + } else { + // nolint as the division by two is literally division by two + padHead = padNeeded / 2 + } + + padTail := padNeeded - padHead + c.pads[i] = padHead + c.pads[i+nSpatialDims] = padTail + } +} + +// getDilatedKernel creates a new kernel given the `dilations` attribute of this +// conv operator. A dilated kernel basically means inserting zeros in between +// the kernels, i.e. a 2D kernel like: +// +// 1 2 +// 3 4 +// +// Dilated by one in both dimensions yields a new kernel of: +// +// 1 0 2 +// 0 0 0 +// 3 0 4 +// +// This function updates the given kernel and dilates it by the given amount +// for each dimensions separately. 
It returns a new tensor with the new kernel. +func (c *Conv1) getDilatedKernel(kernel tensor.Tensor) (tensor.Tensor, error) { + oldKernelShape := kernel.Shape() + newKernelShape := make([]int, len(oldKernelShape)) + + // Add the non spatial dimensions of the kernel, i.e. the number of + // kernels (index 0) and the number of channels (index 1). These + // dimensions do not have to be dilated. + for i := 0; i < nNonSpatialDims; i++ { + newKernelShape[i] = oldKernelShape[i] + } + + // Add the dilated spatial dimensions of the kernel, i.e. in the case + // of 2D images these are the width and height dimensions. + for i, dilation := range c.dilations { + oldKernelDim := oldKernelShape[nNonSpatialDims+i] + newKernelShape[nNonSpatialDims+i] = oldKernelDim + (oldKernelDim-1)*(dilation-1) + } + + newKernel := tensor.NewDense(kernel.Dtype(), newKernelShape) + newKernel.Zero() + + // Now we fill the empty kernel with the original kernel values at the + // right positions. + iterator := kernel.Iterator() + iterator.Reset() + + for !iterator.Done() { + oldCoords := iterator.Coord() + + value, err := kernel.At(oldCoords...) + if err != nil { + return nil, err + } + + newCoords := c.getNewCoordsAfterDilation(oldCoords) + + err = newKernel.SetAt(value, newCoords...) + if err != nil { + return nil, err + } + + _, err = iterator.Next() + if err != nil { + return nil, err + } + } + + c.setKernelShape(newKernel) + + return newKernel, nil +} + +// getNewCoordsAfterDilation returns the new coordinates of a value given the old coordinates of that +// value in the old kernel and its shape. The new coordinates can be used to store the value/weight +// in the dilated kernel. 
+func (c *Conv1) getNewCoordsAfterDilation(oldCoords []int) []int { + newCoords := make([]int, len(oldCoords)) + + for i := 0; i < nNonSpatialDims; i++ { + newCoords[i] = oldCoords[i] + } + + for i, dilation := range c.dilations { + newCoords[nNonSpatialDims+i] = oldCoords[nNonSpatialDims+i] * dilation + } + + return newCoords +} + +// Applies 1D convolution to tensor X with the 'kernel' tensor. +// X will have 3 dimensions: [N, C, H] where N is the batch size, C is the number +// of channels and H is the number of dimensions on which to apply the convolutions. +// The kernel will have shape [kernelDim], where 'kernelDim' is the size of the kernel +// size of the kernel. +func (c *Conv1) applyConv11D(x, kernel tensor.Tensor) (tensor.Tensor, error) { + outputShape := c.getOutputShape(x, kernel) + out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) + out.Zero() + + paddedX, err := c.padInput(x) + if err != nil { + return nil, err + } + + nBatches := x.Shape()[0] + nKernels := kernel.Shape()[0] + strideSize := c.strides[0] + outputHDim := outputShape[nNonSpatialDims] + + for batchIdx := 0; batchIdx < nBatches; batchIdx++ { + for kernelIdx := 0; kernelIdx < nKernels; kernelIdx++ { + subKernelView, err := kernel.Slice(ops.NewSlicer(kernelIdx, kernelIdx+1)) + if err != nil { + return nil, err + } + + subKernel := subKernelView.Materialize() + + for h := 0; h < paddedX.Shape()[2]; h += strideSize { + dimHOutputIdx := h / strideSize + if dimHOutputIdx >= outputHDim { + continue + } + + subImage, err := c.getSubImage(paddedX, batchIdx, h) + if err != nil { + return nil, err + } + + subImage, subKernel, err = ops.UnidirectionalBroadcast(subImage, subKernel) + if err != nil { + return nil, err + } + + convResult, err := tensor.Mul(subImage, subKernel) + if err != nil { + return nil, err + } + + convValue, err := tensor.Sum(convResult) + if err != nil { + return nil, err + } + + err = out.SetAt(convValue.ScalarValue(), batchIdx, kernelIdx, dimHOutputIdx) + if err != 
nil { + return nil, err + } + } + } + } + + return out, nil +} + +// Applies 2D convolution to tensor X with the 'kernel' tensor. +// X will have 4 dimensions: [N, C, H, W] where N is the batch size, C is the number +// of channels, H and W are the height and width dimensions on which to apply the convolutions. +// The kernel will have shape [M, C, H, W]. +func (c *Conv1) applyConv12D(x, kernel tensor.Tensor) (tensor.Tensor, error) { + outputShape := c.getOutputShape(x, kernel) + out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) + out.Zero() + + outputHDim := outputShape[nNonSpatialDims] + outputWDim := outputShape[nNonSpatialDims+1] + + paddedX, err := c.padInput(x) + if err != nil { + return nil, err + } + + nBatches := x.Shape()[0] + nKernels := kernel.Shape()[0] + + for batchIdx := 0; batchIdx < nBatches; batchIdx++ { + for kernelIdx := 0; kernelIdx < nKernels; kernelIdx++ { + subKernelView, err := kernel.Slice(ops.NewSlicer(kernelIdx, kernelIdx+1)) + if err != nil { + return nil, err + } + + subKernel := subKernelView.Materialize() + + // Loop over all 2D subImages of the input image and compute the convolution + // for that subImage. Store the result at the right place in the output tensor. 
+			for h := 0; h < paddedX.Shape()[2]; h += c.strides[0] {
+				dimHOutputIdx := h / c.strides[0]
+				if dimHOutputIdx >= outputHDim {
+					continue
+				}
+
+				for w := 0; w < paddedX.Shape()[3]; w += c.strides[1] {
+					dimWOutputIdx := w / c.strides[1]
+					if dimWOutputIdx >= outputWDim {
+						continue
+					}
+
+					subImage, err := c.getSubImage(paddedX, batchIdx, h, w)
+					if err != nil {
+						return nil, err
+					}
+
+					subImage, subKernel, err = ops.UnidirectionalBroadcast(subImage, subKernel)
+					if err != nil {
+						return nil, err
+					}
+
+					convResult, err := tensor.Mul(subImage, subKernel)
+					if err != nil {
+						return nil, err
+					}
+
+					convValue, err := tensor.Sum(convResult)
+					if err != nil {
+						return nil, err
+					}
+
+					err = out.SetAt(convValue.ScalarValue(), batchIdx, kernelIdx, dimHOutputIdx, dimWOutputIdx)
+					if err != nil {
+						return nil, err
+					}
+				}
+			}
+		}
+	}
+
+	return out, nil
+}
+
+// getOutputShape calculates the shape of the output tensor resulting from
+// the convolution operation between `x` and `kernel`.
+// `x` has shape [N, C, H, W, ...] and `kernel` has shape [M, C, H, W, ...].
+// The output shape will be [N, M, newH, newW, ...], where values like `newH`
+// are calculated based on the input shape, kernel size, padding and strides.
+func (c *Conv1) getOutputShape(x, kernel tensor.Tensor) tensor.Shape {
+	outputShape := make([]int, len(x.Shape()))
+
+	outputShape[0] = x.Shape()[0]
+	outputShape[1] = kernel.Shape()[0]
+
+	nSpatialDims := len(x.Shape()) - nNonSpatialDims
+	for i := 0; i < nSpatialDims; i++ {
+		inputDim := x.Shape()[nNonSpatialDims+i]
+		kernelDim := c.kernelShape[i]
+		outputShape[nNonSpatialDims+i] = ((inputDim - kernelDim + c.pads[i] + c.pads[i+nSpatialDims]) / c.strides[i]) + 1
+	}
+
+	return outputShape
+}
+
+// padInput pads the input with zeros according to the `pads` attribute.
+// The pad attribute specifies how many zeros should be added before and
+// after the values in that specific dimension.
+// Please note that according to ONNX specs, the `pads` attribute is an
+// array with pads as [x1_begin, x2_begin, ..., x1_after, x2_after].
+// This method achieves padding by concatting tensors with zero values
+// before and after each spatial dimension of the input tensor `x`.
+func (c *Conv1) padInput(x tensor.Tensor) (tensor.Tensor, error) {
+	var err error
+
+	nSpatialDims := len(x.Shape()[nNonSpatialDims:])
+
+	for i := 0; i < nSpatialDims; i++ {
+		if c.pads[i] != 0 {
+			padsBeforeShape := x.Shape().Clone()
+			padsBeforeShape[nNonSpatialDims+i] = c.pads[i]
+			zerosBefore := tensor.Tensor(tensor.NewDense(x.Dtype(), padsBeforeShape))
+			zerosBefore.Zero()
+
+			x, err = tensor.Concat(nNonSpatialDims+i, zerosBefore, x)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		if c.pads[i+nSpatialDims] != 0 {
+			padsAfterShape := x.Shape().Clone()
+			padsAfterShape[nNonSpatialDims+i] = c.pads[i+nSpatialDims]
+			zerosAfter := tensor.Tensor(tensor.NewDense(x.Dtype(), padsAfterShape))
+			zerosAfter.Zero()
+
+			x, err = tensor.Concat(nNonSpatialDims+i, x, zerosAfter)
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	return x, nil
+}
+
+// getSubImage returns the subimage for a specific example in the batch, based on the
+// kernel shape and the given start coordinates. The resulting sub image will be of
+// shape [C, kernelShape[0], kernelShape[1], ...].
+func (c *Conv1) getSubImage(x tensor.Tensor, batchIdx int, startSpatialCoords ...int) (tensor.Tensor, error) {
+	if len(startSpatialCoords) != len(c.kernelShape) {
+		return nil, ops.ErrDimension("expected the coordinates to have the same number of dimensions as the kernel")
+	}
+
+	slices := []tensor.Slice{
+		ops.NewSlicer(batchIdx, batchIdx+1),
+		nil, // Take all channels at once.
+	}
+
+	for i := 0; i < len(c.kernelShape); i++ {
+		dimStartIdx := startSpatialCoords[i]
+		dimKernelSize := c.kernelShape[i]
+		slices = append(slices, ops.NewSlicer(dimStartIdx, dimStartIdx+dimKernelSize))
+	}
+
+	subImage, err := x.Slice(slices...)
+ if err != nil { + return nil, err + } + + return subImage.Materialize(), nil +} + +// addBias adds a bias to the output of the convolution. It reshapes the +// bias such that it can be broadcasted, and then is added to the output +// tensor. +func (c *Conv1) addBias(out, bias tensor.Tensor) (tensor.Tensor, error) { + biasShape := make([]int, len(out.Shape())) + for i := 0; i < len(out.Shape()); i++ { + biasShape[i] = 1 + } + + biasShape[1] = bias.Shape()[0] + + err := bias.Reshape(biasShape...) + if err != nil { + return nil, err + } + + out, bias, err = ops.UnidirectionalBroadcast(out, bias) + if err != nil { + return nil, err + } + + return tensor.Add(out, bias) +} diff --git a/ops/opset13/conv.go b/ops/conv/conv_11.go similarity index 89% rename from ops/opset13/conv.go rename to ops/conv/conv_11.go index 801a5e9..bf740fc 100644 --- a/ops/opset13/conv.go +++ b/ops/conv/conv_11.go @@ -1,4 +1,4 @@ -package opset13 +package conv import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,10 +7,10 @@ import ( ) var ( - MinConvInputs = 2 - MaxConvInputs = 3 - NDims1DConvolution = 3 - NDims2DConvolution = 4 + MinConv11Inputs = 2 + MaxConv11Inputs = 3 + NDims1DConv11olution = 3 + NDims2DConv11olution = 4 ) type AutoPadSetting string @@ -28,8 +28,8 @@ const ( // For all tensors, the second dimension will be the number of channels. const nNonSpatialDims = 2 -// Conv represents the ONNX conv operator. -type Conv struct { +// Conv11 represents the ONNX conv operator. +type Conv11 struct { autoPad AutoPadSetting dilations []int group int @@ -38,15 +38,15 @@ type Conv struct { strides []int } -// newConv creates a new conv operator. -func newConv() ops.Operator { - return &Conv{ +// newConv11 creates a new conv operator. +func NewConv11() ops.Operator { + return &Conv11{ autoPad: NotSet, } } // Init initializes the conv operator. 
-func (c *Conv) Init(n *onnx.NodeProto) error { +func (c *Conv11) Init(n *onnx.NodeProto) error { var err error for _, attr := range n.GetAttribute() { @@ -87,7 +87,7 @@ func (c *Conv) Init(n *onnx.NodeProto) error { } // Apply applies the conv operator. -func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Conv11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { x := inputs[0] kernel := inputs[1] bias := inputs[2] @@ -120,10 +120,10 @@ func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var out tensor.Tensor switch len(x.Shape()) { - case NDims1DConvolution: - out, err = c.applyConv1D(x, kernel) - case NDims2DConvolution: - out, err = c.applyConv2D(x, kernel) + case NDims1DConv11olution: + out, err = c.applyConv111D(x, kernel) + case NDims2DConv11olution: + out, err = c.applyConv112D(x, kernel) default: return nil, ops.ErrInvalidInput("the convolution operator currently only supports 1D or 2D convolution, i.e. shape [N x C x H (x W)]", c) } @@ -143,23 +143,23 @@ func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Conv) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Conv11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Conv) GetMinInputs() int { - return MinConvInputs +func (c *Conv11) GetMinInputs() int { + return MinConv11Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Conv) GetMaxInputs() int { - return MaxConvInputs +func (c *Conv11) GetMaxInputs() int { + return MaxConv11Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. 
-func (c *Conv) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *Conv11) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}, @@ -168,13 +168,13 @@ func (c *Conv) GetInputTypeConstraints() [][]tensor.Dtype { } // String implements the stringer interface, and can be used to format errors or messages. -func (c *Conv) String() string { - return "conv operator" +func (c *Conv11) String() string { + return "conv11 operator" } // setDefaultDilations sets the dilations attribute to the default. Can be called when no // dilations were set when initializing. -func (c *Conv) setDefaultDilations(x tensor.Tensor) { +func (c *Conv11) setDefaultDilations(x tensor.Tensor) { nDims := len(x.Shape()[2:]) dilations := make([]int, nDims) @@ -186,13 +186,13 @@ func (c *Conv) setDefaultDilations(x tensor.Tensor) { } // setKernelShape infers the shape of the kernel when it was not given in the attributes. -func (c *Conv) setKernelShape(kernel tensor.Tensor) { +func (c *Conv11) setKernelShape(kernel tensor.Tensor) { c.kernelShape = kernel.Shape()[2:] } // setDefaultPaddings sets default paddings as attribute. Can be called when no paddings // were set during initialization. -func (c *Conv) setDefaultPaddings(x tensor.Tensor) { +func (c *Conv11) setDefaultPaddings(x tensor.Tensor) { NPadsPerDim := 2 paddingLength := len(x.Shape()[2:]) * NPadsPerDim @@ -206,7 +206,7 @@ func (c *Conv) setDefaultPaddings(x tensor.Tensor) { // setDefaultStrides sets default strides as attribute. Can be called when no strides // were set during initialization. 
-func (c *Conv) setDefaultStrides(x tensor.Tensor) { +func (c *Conv11) setDefaultStrides(x tensor.Tensor) { nDims := len(x.Shape()[2:]) strides := make([]int, nDims) @@ -219,7 +219,7 @@ func (c *Conv) setDefaultStrides(x tensor.Tensor) { // setPaddingWithAutoPad sets the padding attribute of the operator based on // the input tensor `x`, the shape of the kernel and the strides. -func (c *Conv) setPaddingWithAutoPad(x tensor.Tensor) { +func (c *Conv11) setPaddingWithAutoPad(x tensor.Tensor) { if c.autoPad == NotSet { return } @@ -266,7 +266,7 @@ func (c *Conv) setPaddingWithAutoPad(x tensor.Tensor) { // // This function updates the given kernel and dilates it by the given amount // for each dimensions separately. It returns a new tensor with the new kernel. -func (c *Conv) getDilatedKernel(kernel tensor.Tensor) (tensor.Tensor, error) { +func (c *Conv11) getDilatedKernel(kernel tensor.Tensor) (tensor.Tensor, error) { oldKernelShape := kernel.Shape() newKernelShape := make([]int, len(oldKernelShape)) @@ -321,7 +321,7 @@ func (c *Conv) getDilatedKernel(kernel tensor.Tensor) (tensor.Tensor, error) { // getNewCoordsAfterDilation returns the new coordinates of a value given the old coordinates of that // value in the old kernel and its shape. The new coordinates can be used to store the value/weight // in the dilated kernel. -func (c *Conv) getNewCoordsAfterDilation(oldCoords []int) []int { +func (c *Conv11) getNewCoordsAfterDilation(oldCoords []int) []int { newCoords := make([]int, len(oldCoords)) for i := 0; i < nNonSpatialDims; i++ { @@ -340,7 +340,7 @@ func (c *Conv) getNewCoordsAfterDilation(oldCoords []int) []int { // of channels and H is the number of dimensions on which to apply the convolutions. // The kernel will have shape [kernelDim], where 'kernelDim' is the size of the kernel // size of the kernel. 
-func (c *Conv) applyConv1D(x, kernel tensor.Tensor) (tensor.Tensor, error) { +func (c *Conv11) applyConv111D(x, kernel tensor.Tensor) (tensor.Tensor, error) { outputShape := c.getOutputShape(x, kernel) out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) out.Zero() @@ -405,7 +405,7 @@ func (c *Conv) applyConv1D(x, kernel tensor.Tensor) (tensor.Tensor, error) { // X will have 4 dimensions: [N, C, H, W] where N is the batch size, C is the number // of channels, H and W are the height and width dimensions on which to apply the convolutions. // The kernel will have shape [M, C, H, W]. -func (c *Conv) applyConv2D(x, kernel tensor.Tensor) (tensor.Tensor, error) { +func (c *Conv11) applyConv112D(x, kernel tensor.Tensor) (tensor.Tensor, error) { outputShape := c.getOutputShape(x, kernel) out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) out.Zero() @@ -481,7 +481,7 @@ func (c *Conv) applyConv2D(x, kernel tensor.Tensor) (tensor.Tensor, error) { // `x` has shape [N, C, H, W, ...] and `kernel` has shape [M, C, H, W, ...]. // The output shape will be [N, M, newH, newW, ...], where values like `newH` // are calculated based on the input shape, kernel size, padding and strides. -func (c *Conv) getOutputShape(x, kernel tensor.Tensor) tensor.Shape { +func (c *Conv11) getOutputShape(x, kernel tensor.Tensor) tensor.Shape { outputShape := make([]int, len(x.Shape())) outputShape[0] = x.Shape()[0] @@ -504,7 +504,7 @@ func (c *Conv) getOutputShape(x, kernel tensor.Tensor) tensor.Shape { // array with pads as [x1_begin, x2_begin, ..., x1_after, x2_after]. // This method achieves padding by concatting tensors with zero values // before and after each spatial dimension of the input tensor `x`. 
-func (c *Conv) padInput(x tensor.Tensor) (tensor.Tensor, error) { +func (c *Conv11) padInput(x tensor.Tensor) (tensor.Tensor, error) { var err error nSpatialDims := len(x.Shape()[nNonSpatialDims:]) @@ -541,7 +541,7 @@ func (c *Conv) padInput(x tensor.Tensor) (tensor.Tensor, error) { // getSubImage returns a the subimage for a specific example in the batch, based on the // kernel shape and the given start coordinates. The resulting sub image will be of // shape [C, kernelShape[0], kernelShape[1], ...]. -func (c *Conv) getSubImage(x tensor.Tensor, batchIdx int, startSpatialCoords ...int) (tensor.Tensor, error) { +func (c *Conv11) getSubImage(x tensor.Tensor, batchIdx int, startSpatialCoords ...int) (tensor.Tensor, error) { if len(startSpatialCoords) != len(c.kernelShape) { return nil, ops.ErrDimension("expected the coordinates to have the same number of dimensions as the kernel") } @@ -568,7 +568,7 @@ func (c *Conv) getSubImage(x tensor.Tensor, batchIdx int, startSpatialCoords ... // addBias adds a bias to the output of the convolution. It reshapes the // bias such that it can be broadcasted, and then is added to the output // tensor. 
-func (c *Conv) addBias(out, bias tensor.Tensor) (tensor.Tensor, error) { +func (c *Conv11) addBias(out, bias tensor.Tensor) (tensor.Tensor, error) { biasShape := make([]int, len(out.Shape())) for i := 0; i < len(out.Shape()); i++ { biasShape[i] = 1 diff --git a/ops/opset13/conv_test.go b/ops/conv/conv_11_test.go similarity index 91% rename from ops/opset13/conv_test.go rename to ops/conv/conv_11_test.go index 8da4b87..fe5eae4 100644 --- a/ops/opset13/conv_test.go +++ b/ops/conv/conv_11_test.go @@ -1,4 +1,4 @@ -package opset13 +package conv import ( "testing" @@ -9,9 +9,9 @@ import ( "gorgonia.org/tensor" ) -func TestConvInit(t *testing.T) { - c := &Conv{} - err := c.Init(Conv2DOnnxNodeProtoFixture()) +func TestConv11Init(t *testing.T) { + c := &Conv11{} + err := c.Init(Conv112DOnnxNodeProtoFixture()) assert.Nil(t, err) @@ -24,9 +24,9 @@ func TestConvInit(t *testing.T) { assert.Equal(t, []int{1, 1}, c.strides) } -func TestConvInitUnsupported(t *testing.T) { - c := &Conv{} - err := c.Init(ConvUnsupportedOnnxNodeProtoFixture()) +func TestConv11InitUnsupported(t *testing.T) { + c := &Conv11{} + err := c.Init(Conv11UnsupportedOnnxNodeProtoFixture()) assert.Equal( t, @@ -35,17 +35,17 @@ func TestConvInitUnsupported(t *testing.T) { ) } -func TestConv(t *testing.T) { +func TestConv11(t *testing.T) { tests := []struct { - conv *Conv + conv *Conv11 shapes [][]int backings [][]float32 expectedShape tensor.Shape expected []float32 }{ - // Test 1D Convolution. + // Test 1D Conv11olution. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -58,9 +58,9 @@ func TestConv(t *testing.T) { []int{1, 1, 4}, []float32{3, 6, 9, 12}, }, - // Test 2D Convolution. + // Test 2D Conv11olution. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -75,7 +75,7 @@ func TestConv(t *testing.T) { }, // Test SAME_LOWER autopad setting. 
{ - &Conv{ + &Conv11{ autoPad: "SAME_LOWER", dilations: []int{}, group: 1, @@ -90,7 +90,7 @@ func TestConv(t *testing.T) { }, // Test SAME_UPPER autopad setting. { - &Conv{ + &Conv11{ autoPad: "SAME_UPPER", dilations: []int{}, group: 1, @@ -105,7 +105,7 @@ func TestConv(t *testing.T) { }, // Test VALID autopad setting. { - &Conv{ + &Conv11{ autoPad: "VALID", dilations: []int{}, group: 1, @@ -120,7 +120,7 @@ func TestConv(t *testing.T) { }, // Test dilation attribute. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{2, 2}, group: 1, @@ -135,7 +135,7 @@ func TestConv(t *testing.T) { }, // Test pads attribute. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{1, 1}, group: 1, @@ -150,7 +150,7 @@ func TestConv(t *testing.T) { }, // Test strides attribute. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -165,7 +165,7 @@ func TestConv(t *testing.T) { }, // Test batch dimension. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -180,7 +180,7 @@ func TestConv(t *testing.T) { }, // Test 2D convolution with multiple channels. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -195,7 +195,7 @@ func TestConv(t *testing.T) { }, // Test multiple kernels. { - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -210,7 +210,7 @@ func TestConv(t *testing.T) { }, // Test bias. 
{ - &Conv{ + &Conv11{ autoPad: "NOTSET", dilations: []int{}, group: 1, @@ -244,7 +244,7 @@ func TestConv(t *testing.T) { } } -func TestInputValidationConv(t *testing.T) { +func TestInputValidationConv11(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -269,19 +269,19 @@ func TestInputValidationConv(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidOptionalInputCount(1, &Conv{}), + ops.ErrInvalidOptionalInputCount(1, &Conv11{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Conv{}), + ops.ErrInvalidInputType(0, "int", &Conv11{}), }, } for _, test := range tests { - conv := &Conv{} + conv := &Conv11{} validated, err := conv.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -293,7 +293,7 @@ func TestInputValidationConv(t *testing.T) { } func TestSetDefaultDilations(t *testing.T) { - c := &Conv{} + c := &Conv11{} x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) c.setDefaultDilations(x) @@ -302,7 +302,7 @@ func TestSetDefaultDilations(t *testing.T) { } func TestSetKernelShape(t *testing.T) { - c := &Conv{} + c := &Conv11{} kernel := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3}, 1, 1, 2, 2) c.setKernelShape(kernel) @@ -311,7 +311,7 @@ func TestSetKernelShape(t *testing.T) { } func TestSetDefaultPaddings(t *testing.T) { - c := &Conv{} + c := &Conv11{} x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) c.setDefaultPaddings(x) @@ -320,7 +320,7 @@ func TestSetDefaultPaddings(t *testing.T) { } func TestSetDefaultStrides(t *testing.T) { - c := &Conv{} + c := &Conv11{} x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) c.setDefaultStrides(x) @@ -342,7 +342,7 @@ func TestSetPaddingWithAutoPad(t *testing.T) { } for _, test := range tests { - conv := &Conv{ + conv := &Conv11{ autoPad: 
test.setting, pads: []int{0, 0, 0, 0}, kernelShape: []int{2, 2}, @@ -407,7 +407,7 @@ func TestGetDilatedKernel(t *testing.T) { } for _, test := range tests { - conv := &Conv{ + conv := &Conv11{ dilations: test.dilations, kernelShape: []int{2, 2}, } @@ -423,7 +423,7 @@ func TestGetDilatedKernel(t *testing.T) { func TestGetOutputShape(t *testing.T) { tests := []struct { - conv *Conv + conv *Conv11 xShape []int xBacking []float32 kernelShape []int @@ -431,7 +431,7 @@ func TestGetOutputShape(t *testing.T) { expected tensor.Shape }{ { - &Conv{ + &Conv11{ kernelShape: []int{3}, pads: []int{0, 0}, strides: []int{1}, @@ -443,7 +443,7 @@ func TestGetOutputShape(t *testing.T) { []int{1, 1, 4}, }, { - &Conv{ + &Conv11{ kernelShape: []int{3}, pads: []int{1, 2}, strides: []int{2}, @@ -455,7 +455,7 @@ func TestGetOutputShape(t *testing.T) { []int{1, 1, 4}, }, { - &Conv{ + &Conv11{ kernelShape: []int{2, 2}, pads: []int{1, 2, 1, 2}, strides: []int{2, 1}, @@ -467,7 +467,7 @@ func TestGetOutputShape(t *testing.T) { []int{1, 1, 3, 7}, }, { - &Conv{ + &Conv11{ kernelShape: []int{2, 2}, pads: []int{0, 0, 0, 0}, strides: []int{1, 1}, @@ -492,14 +492,14 @@ func TestGetOutputShape(t *testing.T) { func TestPadInput(t *testing.T) { tests := []struct { - conv *Conv + conv *Conv11 xShape []int xBacking []float32 expectedShape tensor.Shape expectedBacking []float32 }{ { - &Conv{ + &Conv11{ pads: []int{0, 0}, }, []int{1, 1, 6}, @@ -508,7 +508,7 @@ func TestPadInput(t *testing.T) { []float32{0, 1, 2, 3, 4, 5}, }, { - &Conv{ + &Conv11{ pads: []int{1, 2}, }, []int{1, 1, 6}, @@ -517,7 +517,7 @@ func TestPadInput(t *testing.T) { []float32{0, 0, 1, 2, 3, 4, 5, 0, 0}, }, { - &Conv{ + &Conv11{ pads: []int{1, 1, 1, 1}, }, []int{1, 1, 2, 2}, @@ -526,7 +526,7 @@ func TestPadInput(t *testing.T) { []float32{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}, }, { - &Conv{ + &Conv11{ pads: []int{1, 0, 2, 0}, }, []int{1, 1, 2, 2}, @@ -549,7 +549,7 @@ func TestPadInput(t *testing.T) { func TestGetSubImage(t 
*testing.T) { tests := []struct { - conv *Conv + conv *Conv11 xShape []int xBacking []float32 batchIdx int @@ -558,7 +558,7 @@ func TestGetSubImage(t *testing.T) { expectedBacking []float32 }{ { - &Conv{kernelShape: []int{2}}, + &Conv11{kernelShape: []int{2}}, []int{1, 1, 3}, []float32{0, 1, 2}, 0, @@ -567,7 +567,7 @@ func TestGetSubImage(t *testing.T) { []float32{0, 1}, }, { - &Conv{kernelShape: []int{2}}, + &Conv11{kernelShape: []int{2}}, []int{1, 2, 3}, []float32{0, 1, 2, 3, 4, 5}, 0, @@ -576,7 +576,7 @@ func TestGetSubImage(t *testing.T) { []float32{0, 1, 3, 4}, }, { - &Conv{kernelShape: []int{2, 2}}, + &Conv11{kernelShape: []int{2, 2}}, []int{1, 1, 3, 3}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 0, @@ -585,7 +585,7 @@ func TestGetSubImage(t *testing.T) { []float32{0, 1, 3, 4}, }, { - &Conv{kernelShape: []int{2, 2}}, + &Conv11{kernelShape: []int{2, 2}}, []int{1, 1, 3, 3}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 0, @@ -594,7 +594,7 @@ func TestGetSubImage(t *testing.T) { []float32{4, 5, 7, 8}, }, { - &Conv{kernelShape: []int{2}}, + &Conv11{kernelShape: []int{2}}, []int{2, 1, 3}, []float32{0, 1, 2, 3, 4, 5}, 1, @@ -619,7 +619,7 @@ func TestGetSubImage(t *testing.T) { func TestAddBias(t *testing.T) { tests := []struct { - conv *Conv + conv *Conv11 outShape []int outBacking []float32 biasShape []int @@ -627,7 +627,7 @@ func TestAddBias(t *testing.T) { expected []float32 }{ { - &Conv{}, + &Conv11{}, []int{1, 1, 3}, []float32{0, 1, 2}, []int{1}, @@ -635,7 +635,7 @@ func TestAddBias(t *testing.T) { []float32{0.5, 1.5, 2.5}, }, { - &Conv{}, + &Conv11{}, []int{1, 1, 3, 3}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, []int{1}, @@ -643,7 +643,7 @@ func TestAddBias(t *testing.T) { []float32{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5}, }, { - &Conv{}, + &Conv11{}, []int{1, 2, 2, 2}, []float32{0, 1, 2, 3, 4, 5, 6, 7}, []int{2}, @@ -651,7 +651,7 @@ func TestAddBias(t *testing.T) { []float32{-1, 0, 1, 2, 5, 6, 7, 8}, }, { - &Conv{}, + &Conv11{}, []int{2, 2, 2, 2}, []float32{0, 1, 2, 3, 
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, []int{2}, @@ -671,7 +671,7 @@ func TestAddBias(t *testing.T) { } } -func Conv2DOnnxNodeProtoFixture() *onnx.NodeProto { +func Conv112DOnnxNodeProtoFixture() *onnx.NodeProto { return &onnx.NodeProto{ Attribute: []*onnx.AttributeProto{ {Name: "auto_pad", S: []byte("VALID")}, @@ -683,7 +683,7 @@ func Conv2DOnnxNodeProtoFixture() *onnx.NodeProto { } } -func ConvUnsupportedOnnxNodeProtoFixture() *onnx.NodeProto { +func Conv11UnsupportedOnnxNodeProtoFixture() *onnx.NodeProto { return &onnx.NodeProto{ Attribute: []*onnx.AttributeProto{ {Name: "group", I: 2}, diff --git a/ops/opset13/cos.go b/ops/cos/cos_7.go similarity index 70% rename from ops/opset13/cos.go rename to ops/cos/cos_7.go index ad01f82..e086ad7 100644 --- a/ops/opset13/cos.go +++ b/ops/cos/cos_7.go @@ -1,4 +1,4 @@ -package opset13 +package cos import ( "math" @@ -8,21 +8,21 @@ import ( "gorgonia.org/tensor" ) -// Cos represents the ONNX cos operator. -type Cos struct{} +// Cos7 represents the ONNX cos operator. +type Cos7 struct{} -// newCos creates a new cos operator. -func newCos() ops.Operator { - return &Cos{} +// newCos7 creates a new cos operator. +func NewCos7() ops.Operator { + return &Cos7{} } // Init initializes the cos operator. -func (c *Cos) Init(*onnx.NodeProto) error { +func (c *Cos7) Init(*onnx.NodeProto) error { return nil } // Apply applies the cos operator. -func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Cos7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var ( out tensor.Tensor err error @@ -45,29 +45,29 @@ func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. 
-func (c *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Cos7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Cos) GetMinInputs() int { +func (c *Cos7) GetMinInputs() int { return 1 } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Cos) GetMaxInputs() int { +func (c *Cos7) GetMaxInputs() int { return 1 } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (c *Cos) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *Cos7) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} } // String implements the stringer interface, and can be used to format errors or messages. -func (c *Cos) String() string { - return "cos operator" +func (c *Cos7) String() string { + return "cos7 operator" } func cos[T ops.FloatType](x T) T { diff --git a/ops/opset13/cos_test.go b/ops/cos/cos_7_test.go similarity index 83% rename from ops/opset13/cos_test.go rename to ops/cos/cos_7_test.go index b1087c4..d8af997 100644 --- a/ops/opset13/cos_test.go +++ b/ops/cos/cos_7_test.go @@ -1,4 +1,4 @@ -package opset13 +package cos import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestCosInit(t *testing.T) { - c := &Cos{} +func TestCos7Init(t *testing.T) { + c := &Cos7{} // since 'cos' does not have any attributes we pass in nil. This should not // fail initializing the cos. 
@@ -17,27 +17,27 @@ func TestCosInit(t *testing.T) { assert.Nil(t, err) } -func TestCos(t *testing.T) { +func TestCos7(t *testing.T) { tests := []struct { - cos *Cos + cos *Cos7 backing []float32 shape []int expected []float32 }{ { - &Cos{}, + &Cos7{}, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{-0.41614684, 0.5403023, 1, 0.5403023}, }, { - &Cos{}, + &Cos7{}, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{0.5403023, -0.9899925, -0.6536436, 0.2836622}, }, { - &Cos{}, + &Cos7{}, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{0.5403023, 0.5403023, 0.5403023, 0.5403023}, @@ -57,7 +57,7 @@ func TestCos(t *testing.T) { } } -func TestInputValidationCos(t *testing.T) { +func TestInputValidationCos7(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -76,18 +76,18 @@ func TestInputValidationCos(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Cos{}), + ops.ErrInvalidInputCount(0, &Cos7{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Cos{}), + ops.ErrInvalidInputType(0, "int", &Cos7{}), }, } for _, test := range tests { - cos := &Cos{} + cos := &Cos7{} validated, err := cos.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/opset13/cosh.go b/ops/cosh/cosh_9.go similarity index 69% rename from ops/opset13/cosh.go rename to ops/cosh/cosh_9.go index cddb129..0e8e586 100644 --- a/ops/opset13/cosh.go +++ b/ops/cosh/cosh_9.go @@ -1,4 +1,4 @@ -package opset13 +package cosh import ( "math" @@ -8,21 +8,21 @@ import ( "gorgonia.org/tensor" ) -// Cosh represents the ONNX cosh operator. -type Cosh struct{} +// Cosh9 represents the ONNX cosh operator. +type Cosh9 struct{} -// newCosh creates a new cosh operator. -func newCosh() ops.Operator { - return &Cosh{} +// newCosh9 creates a new cosh operator. +func NewCosh9() ops.Operator { + return &Cosh9{} } // Init initializes the cosh operator. 
-func (c *Cosh) Init(*onnx.NodeProto) error { +func (c *Cosh9) Init(*onnx.NodeProto) error { return nil } // Apply applies the cosh operator. -func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Cosh9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var ( out tensor.Tensor err error @@ -45,29 +45,29 @@ func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Cosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Cosh9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Cosh) GetMinInputs() int { +func (c *Cosh9) GetMinInputs() int { return 1 } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Cosh) GetMaxInputs() int { +func (c *Cosh9) GetMaxInputs() int { return 1 } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (c *Cosh) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *Cosh9) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} } // String implements the stringer interface, and can be used to format errors or messages. 
-func (c *Cosh) String() string { - return "cosh operator" +func (c *Cosh9) String() string { + return "cosh9 operator" } func cosh[T ops.FloatType](x T) T { diff --git a/ops/opset13/cosh_test.go b/ops/cosh/cosh_9_test.go similarity index 83% rename from ops/opset13/cosh_test.go rename to ops/cosh/cosh_9_test.go index 3359ada..5717a7d 100644 --- a/ops/opset13/cosh_test.go +++ b/ops/cosh/cosh_9_test.go @@ -1,4 +1,4 @@ -package opset13 +package cosh import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestCoshInit(t *testing.T) { - c := &Cosh{} +func TestCosh9Init(t *testing.T) { + c := &Cosh9{} // since 'cosh' does not have any attributes we pass in nil. This should not // fail initializing the cosh. @@ -17,27 +17,27 @@ func TestCoshInit(t *testing.T) { assert.Nil(t, err) } -func TestCosh(t *testing.T) { +func TestCosh9(t *testing.T) { tests := []struct { - cosh *Cosh + cosh *Cosh9 backing []float32 shape []int expected []float32 }{ { - &Cosh{}, + &Cosh9{}, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{3.7621956, 1.5430807, 1, 1.5430807}, }, { - &Cosh{}, + &Cosh9{}, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{1.5430807, 10.067662, 27.308233, 74.209946}, }, { - &Cosh{}, + &Cosh9{}, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{1.5430807, 1.5430807, 1.5430807, 1.5430807}, @@ -57,7 +57,7 @@ func TestCosh(t *testing.T) { } } -func TestInputValidationCosh(t *testing.T) { +func TestInputValidationCosh9(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -76,18 +76,18 @@ func TestInputValidationCosh(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Cosh{}), + ops.ErrInvalidInputCount(0, &Cosh9{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Cosh{}), + ops.ErrInvalidInputType(0, "int", &Cosh9{}), }, } for _, test := range tests { - cosh := &Cosh{} + cosh := &Cosh9{} validated, err := cosh.ValidateInputs(test.inputs) assert.Equal(t, test.err, 
err) diff --git a/ops/div/div_13.go b/ops/div/div_13.go new file mode 100644 index 0000000..882fd24 --- /dev/null +++ b/ops/div/div_13.go @@ -0,0 +1,64 @@ +package div + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinDiv13Inputs = 2 + MaxDiv13Inputs = 2 +) + +// Div13 represents the ONNX div operator. +type Div13 struct{} + +// newDiv13 creates a new div operator. +func NewDiv13() ops.Operator { + return &Div13{} +} + +// Init initializes the div operator. +func (d *Div13) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the div operator. +func (d *Div13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Div, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (d *Div13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(d, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (d *Div13) GetMinInputs() int { + return MinDiv13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (d *Div13) GetMaxInputs() int { + return MaxDiv13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (d *Div13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. 
+func (d *Div13) String() string { + return "div13 operator" +} diff --git a/ops/opset13/div_test.go b/ops/div/div_13_test.go similarity index 88% rename from ops/opset13/div_test.go rename to ops/div/div_13_test.go index 06a4f45..f28c8ba 100644 --- a/ops/opset13/div_test.go +++ b/ops/div/div_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package div import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestDivInit(t *testing.T) { - div := &Div{} +func TestDiv13Init(t *testing.T) { + div := &Div13{} // since the div does not have any attributes we pass in nil. This should not // fail initializing the div. @@ -17,27 +17,27 @@ func TestDivInit(t *testing.T) { assert.Nil(t, err) } -func TestDiv(t *testing.T) { +func TestDiv13(t *testing.T) { tests := []struct { - div *Div + div *Div13 shapes [][]int backings [][]float32 expected []float32 }{ { - &Div{}, + &Div13{}, [][]int{{2, 2}, {2, 2}}, [][]float32{{10, 10, 10, 10}, {2, 5, 2.5, 1.0}}, []float32{5, 2, 4, 10}, }, { - &Div{}, + &Div13{}, [][]int{{2, 2}, {2}}, [][]float32{{1, 1, 1, 1}, {1, 2}}, []float32{1, 0.5, 1, 0.5}, }, { - &Div{}, + &Div13{}, [][]int{{2, 2}, {1}}, [][]float32{{1, 1, 1, 1}, {2}}, []float32{0.5, 0.5, 0.5, 0.5}, @@ -56,7 +56,7 @@ func TestDiv(t *testing.T) { } } -func TestInputValidationDiv(t *testing.T) { +func TestInputValidationDiv13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -107,19 +107,19 @@ func TestInputValidationDiv(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Div{}), + ops.ErrInvalidInputCount(1, &Div13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Div{}), + ops.ErrInvalidInputType(0, "int", &Div13{}), }, } for _, test := range tests { - div := &Div{} + div := &Div13{} validated, err := div.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git 
a/ops/opset13/div.go b/ops/div/div_7.go similarity index 63% rename from ops/opset13/div.go rename to ops/div/div_7.go index e918e7f..a8cbdff 100644 --- a/ops/opset13/div.go +++ b/ops/div/div_7.go @@ -1,4 +1,4 @@ -package opset13 +package div import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,25 +7,25 @@ import ( ) const ( - MinDivInputs = 2 - MaxDivInputs = 2 + MinDiv7Inputs = 2 + MaxDiv7Inputs = 2 ) -// Div represents the ONNX div operator. -type Div struct{} +// Div7 represents the ONNX div operator. +type Div7 struct{} -// newDiv creates a new div operator. -func newDiv() ops.Operator { - return &Div{} +// newDiv7 creates a new div operator. +func NewDiv7() ops.Operator { + return &Div7{} } // Init initializes the div operator. -func (d *Div) Init(*onnx.NodeProto) error { +func (d *Div7) Init(*onnx.NodeProto) error { return nil } // Apply applies the div operator. -func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (d *Div7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ApplyBinaryOperation( inputs[0], inputs[1], @@ -35,23 +35,23 @@ func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (d *Div) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (d *Div7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(d, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (d *Div) GetMinInputs() int { - return MinDivInputs +func (d *Div7) GetMinInputs() int { + return MinDiv7Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (d *Div) GetMaxInputs() int { - return MaxDivInputs +func (d *Div7) GetMaxInputs() int { + return MaxDiv7Inputs } // GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (d *Div) GetInputTypeConstraints() [][]tensor.Dtype { +func (d *Div7) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, @@ -59,6 +59,6 @@ func (d *Div) GetInputTypeConstraints() [][]tensor.Dtype { } // String implements the stringer interface, and can be used to format errors or messages. -func (d *Div) String() string { - return "div operator" +func (d *Div7) String() string { + return "div7 operator" } diff --git a/ops/opset13/equal.go b/ops/equal/equal_11.go similarity index 57% rename from ops/opset13/equal.go rename to ops/equal/equal_11.go index db888b8..bbe468a 100644 --- a/ops/opset13/equal.go +++ b/ops/equal/equal_11.go @@ -1,4 +1,4 @@ -package opset13 +package equal import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,25 +7,25 @@ import ( ) var ( - MinEqualInputs = 2 - MaxEqualInputs = 2 + MinEqual11Inputs = 2 + MaxEqual11Inputs = 2 ) -// Equal represents the ONNX equal operator. -type Equal struct{} +// Equal11 represents the ONNX equal operator. +type Equal11 struct{} -// newEqual creates a new equal operator. -func newEqual() ops.Operator { - return &Equal{} +// newEqual11 creates a new equal operator. +func NewEqual11() ops.Operator { + return &Equal11{} } // Init initializes the equal operator. -func (e *Equal) Init(*onnx.NodeProto) error { +func (e *Equal11) Init(*onnx.NodeProto) error { return nil } // Apply applies the equal operator. 
-func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (e *Equal11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ApplyBinaryOperation( inputs[0], inputs[1], @@ -35,27 +35,27 @@ func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (e *Equal) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (e *Equal11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(e, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (e *Equal) GetMinInputs() int { - return MinEqualInputs +func (e *Equal11) GetMinInputs() int { + return MinEqual11Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (e *Equal) GetMaxInputs() int { - return MaxEqualInputs +func (e *Equal11) GetMaxInputs() int { + return MaxEqual11Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (e *Equal) GetInputTypeConstraints() [][]tensor.Dtype { +func (e *Equal11) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} } // String implements the stringer interface, and can be used to format errors or messages. -func (e *Equal) String() string { - return "equal operator" +func (e *Equal11) String() string { + return "equal11 operator" } diff --git a/ops/equal/equal_13.go b/ops/equal/equal_13.go new file mode 100644 index 0000000..55a52cd --- /dev/null +++ b/ops/equal/equal_13.go @@ -0,0 +1,61 @@ +package equal + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinEqual13Inputs = 2 + MaxEqual13Inputs = 2 +) + +// Equal13 represents the ONNX equal operator. 
+type Equal13 struct{} + +// newEqual13 creates a new equal operator. +func NewEqual13() ops.Operator { + return &Equal13{} +} + +// Init initializes the equal operator. +func (e *Equal13) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the equal operator. +func (e *Equal13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Equal, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (e *Equal13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(e, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (e *Equal13) GetMinInputs() int { + return MinEqual13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (e *Equal13) GetMaxInputs() int { + return MaxEqual13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (e *Equal13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (e *Equal13) String() string { + return "equal13 operator" +} diff --git a/ops/opset13/equal_test.go b/ops/equal/equal_13_test.go similarity index 88% rename from ops/opset13/equal_test.go rename to ops/equal/equal_13_test.go index 9014e78..27a09fa 100644 --- a/ops/opset13/equal_test.go +++ b/ops/equal/equal_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package equal import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestEqualInit(t *testing.T) { - e := &Equal{} +func TestEqual13Init(t *testing.T) { + e := &Equal13{} // since 'equal' does not have any attributes we pass in nil. 
This should not // fail initializing the equal. @@ -17,27 +17,27 @@ func TestEqualInit(t *testing.T) { assert.Nil(t, err) } -func TestEqual(t *testing.T) { +func TestEqual13(t *testing.T) { tests := []struct { - equal *Equal + equal *Equal13 backings [][]float32 shapes [][]int expected []bool }{ { - &Equal{}, + &Equal13{}, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, true, false, false}, }, { - &Equal{}, + &Equal13{}, [][]float32{{0, 1, 2, 2, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, true, true, false, false}, }, { - &Equal{}, + &Equal13{}, [][]float32{{0, 1}, {0, 1, 0, 1}}, [][]int{{2}, {2, 2}}, []bool{true, true, true, true}, @@ -58,7 +58,7 @@ func TestEqual(t *testing.T) { } } -func TestInputValidationEqual(t *testing.T) { +func TestInputValidationEqual13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -109,19 +109,19 @@ func TestInputValidationEqual(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Equal{}), + ops.ErrInvalidInputCount(1, &Equal13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Equal{}), + ops.ErrInvalidInputType(0, "int", &Equal13{}), }, } for _, test := range tests { - equal := &Equal{} + equal := &Equal13{} validated, err := equal.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/equal/equal_7.go b/ops/equal/equal_7.go new file mode 100644 index 0000000..79583b2 --- /dev/null +++ b/ops/equal/equal_7.go @@ -0,0 +1,61 @@ +package equal + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinEqual7Inputs = 2 + MaxEqual7Inputs = 2 +) + +// Equal7 represents the ONNX equal operator. +type Equal7 struct{} + +// newEqual7 creates a new equal operator. 
+func NewEqual7() ops.Operator { + return &Equal7{} +} + +// Init initializes the equal operator. +func (e *Equal7) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the equal operator. +func (e *Equal7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Equal, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (e *Equal7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(e, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (e *Equal7) GetMinInputs() int { + return MinEqual7Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (e *Equal7) GetMaxInputs() int { + return MaxEqual7Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (e *Equal7) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Bool, tensor.Int32, tensor.Int64}, {tensor.Bool, tensor.Int32, tensor.Int64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (e *Equal7) String() string { + return "equal7 operator" +} diff --git a/ops/expand/expand_13.go b/ops/expand/expand_13.go new file mode 100644 index 0000000..ebedb1e --- /dev/null +++ b/ops/expand/expand_13.go @@ -0,0 +1,81 @@ +package expand + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinExpand13Inputs = 2 + MaxExpand13Inputs = 2 +) + +// Expand13 represents the ONNX expand operator. +type Expand13 struct{} + +// newExpand13 creates a new expand operator. 
+func NewExpand13() ops.Operator { + return &Expand13{} +} + +// Init initializes the expand operator. +func (f *Expand13) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the expand operator. +func (f *Expand13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + input := inputs[0] + + shape, err := ops.AnyToIntSlice(inputs[1].Data()) + if err != nil { + return nil, err + } + + // If the new shape has more dimensions than the input tensor, we + // need to prepend some dimensions to the input tensor shape. + if len(shape) > len(input.Shape()) { + input, err = ops.AddExtraDimsToTensor(input, len(shape)-len(input.Shape())) + if err != nil { + return nil, err + } + } + + for axis := len(shape) - 1; axis >= 0; axis-- { + if input.Shape()[axis] != shape[axis] { + input, err = tensor.Repeat(input, axis, shape[axis]) + if err != nil { + return nil, err + } + } + } + + return []tensor.Tensor{input}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (f *Expand13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(f, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (f *Expand13) GetMinInputs() int { + return MinExpand13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (f *Expand13) GetMaxInputs() int { + return MaxExpand13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (f *Expand13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. 
+func (f *Expand13) String() string { + return "expand13 operator" +} diff --git a/ops/opset13/expand_test.go b/ops/expand/expand_13_test.go similarity index 88% rename from ops/opset13/expand_test.go rename to ops/expand/expand_13_test.go index 325d200..5ae7f3c 100644 --- a/ops/opset13/expand_test.go +++ b/ops/expand/expand_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package expand import ( "testing" @@ -8,16 +8,16 @@ import ( "gorgonia.org/tensor" ) -func TestExpandInit(t *testing.T) { - e := &Expand{} +func TestExpand13Init(t *testing.T) { + e := &Expand13{} err := e.Init(nil) assert.Nil(t, err) } -func TestExpand(t *testing.T) { +func TestExpand13(t *testing.T) { tests := []struct { - expand *Expand + expand *Expand13 backing []float32 shape []int newShapeBacking []int64 @@ -25,7 +25,7 @@ func TestExpand(t *testing.T) { expectedData []float32 }{ { - &Expand{}, + &Expand13{}, []float32{0, 1, 2, 3}, []int{2, 2}, []int64{1, 1, 1}, @@ -33,7 +33,7 @@ func TestExpand(t *testing.T) { []float32{0, 1, 2, 3}, }, { - &Expand{}, + &Expand13{}, []float32{0, 1, 2, 3}, []int{2, 2}, []int64{1, 3, 1, 1}, @@ -56,7 +56,7 @@ func TestExpand(t *testing.T) { } } -func TestInputValidationExpand(t *testing.T) { +func TestInputValidationExpand13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -109,19 +109,19 @@ func TestInputValidationExpand(t *testing.T) { ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - ops.ErrInvalidInputCount(3, &Expand{}), + ops.ErrInvalidInputCount(3, &Expand13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), }, - ops.ErrInvalidInputType(0, "int", &Expand{}), + ops.ErrInvalidInputType(0, "int", &Expand13{}), }, } for _, test := range tests { - expand := &Expand{} + expand := &Expand13{} validated, err := expand.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/opset13/expand.go 
b/ops/expand/expand_8.go similarity index 68% rename from ops/opset13/expand.go rename to ops/expand/expand_8.go index f84fb3a..cf399ab 100644 --- a/ops/opset13/expand.go +++ b/ops/expand/expand_8.go @@ -1,4 +1,4 @@ -package opset13 +package expand import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,25 +7,25 @@ import ( ) const ( - MinExpandInputs = 2 - MaxExpandInputs = 2 + MinExpand8Inputs = 2 + MaxExpand8Inputs = 2 ) -// Expand represents the ONNX expand operator. -type Expand struct{} +// Expand8 represents the ONNX expand operator. +type Expand8 struct{} -// newExpand creates a new expand operator. -func newExpand() ops.Operator { - return &Expand{} +// newExpand8 creates a new expand operator. +func NewExpand8() ops.Operator { + return &Expand8{} } // Init initializes the expand operator. -func (f *Expand) Init(*onnx.NodeProto) error { +func (f *Expand8) Init(*onnx.NodeProto) error { return nil } // Apply applies the expand operator. -func (f *Expand) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (f *Expand8) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { input := inputs[0] shape, err := ops.AnyToIntSlice(inputs[1].Data()) @@ -55,27 +55,27 @@ func (f *Expand) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (f *Expand) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (f *Expand8) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(f, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (f *Expand) GetMinInputs() int { - return MinExpandInputs +func (f *Expand8) GetMinInputs() int { + return MinExpand8Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. 
-func (f *Expand) GetMaxInputs() int { - return MaxExpandInputs +func (f *Expand8) GetMaxInputs() int { + return MaxExpand8Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (f *Expand) GetInputTypeConstraints() [][]tensor.Dtype { +func (f *Expand8) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} } // String implements the stringer interface, and can be used to format errors or messages. -func (f *Expand) String() string { - return "expand operator" +func (f *Expand8) String() string { + return "expand8 operator" } diff --git a/ops/flatten/flatten_1.go b/ops/flatten/flatten_1.go new file mode 100644 index 0000000..dda0910 --- /dev/null +++ b/ops/flatten/flatten_1.go @@ -0,0 +1,88 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinFlatten1Inputs = 1 + MaxFlatten1Inputs = 1 +) + +// Flatten1 represents the ONNX flatten operator. +type Flatten1 struct { + axis int +} + +// newFlatten1 creates a new flatten operator. +func NewFlatten1() ops.Operator { + return &Flatten1{ + axis: 1, + } +} + +// Init initializes the flatten operator. +func (f *Flatten1) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axis": + f.axis = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), f) + } + } + + return nil +} + +// Apply applies the flatten operator. +func (f *Flatten1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + inputShape := inputs[0].Shape() + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + var err error + // In the special case where axis is 0, we reshape the tensor to shape + // (1, ). This is ONNX defined behaviour. 
+ if f.axis == 0 { + err = out.Reshape(1, ops.NElements(inputShape...)) + } else { + err = out.Reshape(ops.NElements(inputShape[:f.axis]...), ops.NElements(inputShape[f.axis:]...)) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (f *Flatten1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(f, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (f *Flatten1) GetMinInputs() int { + return MinFlatten1Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (f *Flatten1) GetMaxInputs() int { + return MaxFlatten1Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (f *Flatten1) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (f *Flatten1) String() string { + return "flatten1 operator" +} diff --git a/ops/opset13/flatten.go b/ops/flatten/flatten_11.go similarity index 69% rename from ops/opset13/flatten.go rename to ops/flatten/flatten_11.go index 50e9039..a2cc90b 100644 --- a/ops/opset13/flatten.go +++ b/ops/flatten/flatten_11.go @@ -1,4 +1,4 @@ -package opset13 +package flatten import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,24 +7,24 @@ import ( ) const ( - MinFlattenInputs = 1 - MaxFlattenInputs = 1 + MinFlatten11Inputs = 1 + MaxFlatten11Inputs = 1 ) -// Flatten represents the ONNX flatten operator. -type Flatten struct { +// Flatten11 represents the ONNX flatten operator. +type Flatten11 struct { axis int } -// newFlatten creates a new flatten operator. 
-func newFlatten() ops.Operator { - return &Flatten{ +// newFlatten11 creates a new flatten operator. +func NewFlatten11() ops.Operator { + return &Flatten11{ axis: 1, } } // Init initializes the flatten operator. -func (f *Flatten) Init(n *onnx.NodeProto) error { +func (f *Flatten11) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { case "axis": @@ -38,7 +38,7 @@ func (f *Flatten) Init(n *onnx.NodeProto) error { } // Apply applies the flatten operator. -func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (f *Flatten11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { inputShape := inputs[0].Shape() rank := len(inputShape) @@ -69,27 +69,27 @@ func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (f *Flatten) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (f *Flatten11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(f, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (f *Flatten) GetMinInputs() int { - return MinFlattenInputs +func (f *Flatten11) GetMinInputs() int { + return MinFlatten11Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (f *Flatten) GetMaxInputs() int { - return MaxFlattenInputs +func (f *Flatten11) GetMaxInputs() int { + return MaxFlatten11Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (f *Flatten) GetInputTypeConstraints() [][]tensor.Dtype { +func (f *Flatten11) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ops.AllTypes} } // String implements the stringer interface, and can be used to format errors or messages. 
-func (f *Flatten) String() string { - return "flatten operator" +func (f *Flatten11) String() string { + return "flatten11 operator" } diff --git a/ops/flatten/flatten_13.go b/ops/flatten/flatten_13.go new file mode 100644 index 0000000..fba34b9 --- /dev/null +++ b/ops/flatten/flatten_13.go @@ -0,0 +1,95 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinFlatten13Inputs = 1 + MaxFlatten13Inputs = 1 +) + +// Flatten13 represents the ONNX flatten operator. +type Flatten13 struct { + axis int +} + +// newFlatten13 creates a new flatten operator. +func NewFlatten13() ops.Operator { + return &Flatten13{ + axis: 1, + } +} + +// Init initializes the flatten operator. +func (f *Flatten13) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axis": + f.axis = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), f) + } + } + + return nil +} + +// Apply applies the flatten operator. +func (f *Flatten13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + inputShape := inputs[0].Shape() + rank := len(inputShape) + + axis := f.axis + if axis < 0 { + axis = rank + axis + } + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + var err error + // In the special case where axis is 0, we reshape the tensor to shape + // (1, ). This is ONNX defined behaviour. + if axis == 0 { + err = out.Reshape(1, ops.NElements(inputShape...)) + } else { + err = out.Reshape(ops.NElements(inputShape[:axis]...), ops.NElements(inputShape[axis:]...)) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. 
+func (f *Flatten13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(f, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (f *Flatten13) GetMinInputs() int { + return MinFlatten13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (f *Flatten13) GetMaxInputs() int { + return MaxFlatten13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (f *Flatten13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (f *Flatten13) String() string { + return "flatten13 operator" +} diff --git a/ops/opset13/flatten_test.go b/ops/flatten/flatten_13_test.go similarity index 83% rename from ops/opset13/flatten_test.go rename to ops/flatten/flatten_13_test.go index 4a750e1..6d50a43 100644 --- a/ops/opset13/flatten_test.go +++ b/ops/flatten/flatten_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package flatten import ( "testing" @@ -9,8 +9,8 @@ import ( "gorgonia.org/tensor" ) -func TestFlattenInit(t *testing.T) { - f := &Flatten{} +func TestFlatten13Init(t *testing.T) { + f := &Flatten13{} err := f.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 2}}}) assert.Nil(t, err) @@ -18,63 +18,63 @@ func TestFlattenInit(t *testing.T) { assert.Equal(t, 2, f.axis) } -func TestFlatten(t *testing.T) { +func TestFlatten13(t *testing.T) { tests := []struct { - flatten *Flatten + flatten *Flatten13 backing []float32 shape []int expectedShape tensor.Shape }{ { - &Flatten{}, + &Flatten13{}, []float32{0, 1, 2, 3}, []int{2, 2}, []int{1, 4}, }, { - &Flatten{}, + &Flatten13{}, []float32{0, 1, 2, 3, 4, 5}, []int{2, 3}, []int{1, 6}, }, { - &Flatten{axis: 1}, + &Flatten13{axis: 1}, []float32{0, 1, 2, 3, 
4, 5, 6, 7}, []int{2, 2, 2}, []int{2, 4}, }, { - &Flatten{axis: 2}, + &Flatten13{axis: 2}, []float32{0, 1, 2, 3, 4, 5, 6, 7}, []int{2, 2, 2}, []int{4, 2}, }, { - &Flatten{axis: -1}, + &Flatten13{axis: -1}, []float32{0, 1, 2, 3, 4, 5, 6, 7}, []int{2, 2, 2}, []int{4, 2}, }, { - &Flatten{axis: -2}, + &Flatten13{axis: -2}, []float32{0, 1, 2, 3, 4, 5, 6, 7}, []int{2, 2, 2}, []int{2, 4}, }, { - &Flatten{axis: -3}, + &Flatten13{axis: -3}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, []int{3, 2, 3}, []int{1, 18}, }, { - &Flatten{axis: 2}, + &Flatten13{axis: 2}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, []int{3, 2, 3}, []int{6, 3}, }, { - &Flatten{axis: 1}, + &Flatten13{axis: 1}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, []int{3, 2, 3}, []int{3, 6}, @@ -93,7 +93,7 @@ func TestFlatten(t *testing.T) { } } -func TestInputValidationFlatten(t *testing.T) { +func TestInputValidationFlatten13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -139,18 +139,18 @@ func TestInputValidationFlatten(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - ops.ErrInvalidInputCount(2, &Flatten{}), + ops.ErrInvalidInputCount(2, &Flatten13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Flatten{}), + ops.ErrInvalidInputType(0, "int", &Flatten13{}), }, } for _, test := range tests { - flatten := &Flatten{} + flatten := &Flatten13{} validated, err := flatten.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/flatten/flatten_9.go b/ops/flatten/flatten_9.go new file mode 100644 index 0000000..1e5ef04 --- /dev/null +++ b/ops/flatten/flatten_9.go @@ -0,0 +1,88 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + 
MinFlatten9Inputs = 1 + MaxFlatten9Inputs = 1 +) + +// Flatten9 represents the ONNX flatten operator. +type Flatten9 struct { + axis int +} + +// newFlatten9 creates a new flatten operator. +func NewFlatten9() ops.Operator { + return &Flatten9{ + axis: 1, + } +} + +// Init initializes the flatten operator. +func (f *Flatten9) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axis": + f.axis = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), f) + } + } + + return nil +} + +// Apply applies the flatten operator. +func (f *Flatten9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + inputShape := inputs[0].Shape() + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + var err error + // In the special case where axis is 0, we reshape the tensor to shape + // (1, ). This is ONNX defined behaviour. + if f.axis == 0 { + err = out.Reshape(1, ops.NElements(inputShape...)) + } else { + err = out.Reshape(ops.NElements(inputShape[:f.axis]...), ops.NElements(inputShape[f.axis:]...)) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (f *Flatten9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(f, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (f *Flatten9) GetMinInputs() int { + return MinFlatten9Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (f *Flatten9) GetMaxInputs() int { + return MaxFlatten9Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. 
+func (f *Flatten9) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (f *Flatten9) String() string { + return "flatten9 operator" +} diff --git a/opset.go b/opset.go index 342738c..00bcf5f 100644 --- a/opset.go +++ b/opset.go @@ -14,6 +14,15 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/atanh" "github.com/advancedclimatesystems/gonnx/ops/cast" "github.com/advancedclimatesystems/gonnx/ops/concat" + "github.com/advancedclimatesystems/gonnx/ops/constant" + "github.com/advancedclimatesystems/gonnx/ops/constantofshape" + "github.com/advancedclimatesystems/gonnx/ops/conv" + "github.com/advancedclimatesystems/gonnx/ops/cos" + "github.com/advancedclimatesystems/gonnx/ops/cosh" + "github.com/advancedclimatesystems/gonnx/ops/div" + "github.com/advancedclimatesystems/gonnx/ops/equal" + "github.com/advancedclimatesystems/gonnx/ops/expand" + "github.com/advancedclimatesystems/gonnx/ops/flatten" ) const ( @@ -28,7 +37,7 @@ type OperatorVersions map[int64]func() ops.Operator var operators = map[string]OperatorVersions{ "Abs": { - 6: abs.NewAbs6, + 6: abs.NewAbs6, // Same, but bfloat16 type is added 13: abs.NewAbs13, }, "Acos": { @@ -38,15 +47,15 @@ var operators = map[string]OperatorVersions{ 9: acosh.NewAcosh9, }, "Add": { - 7: add.NewAdd7, + 7: add.NewAdd7, // Same, but bfloat16 type is added 13: add.NewAdd13, }, "And": { 7: and.NewAnd7, }, "ArgMax": { - 11: argmax.NewArgMax11, - 12: argmax.NewArgMax12, + 11: argmax.NewArgMax11, // Same, but one attribute is added (which we don't support it anyway) + 12: argmax.NewArgMax12, // Same, but bfloat16 type differs 13: argmax.NewArgMax13, }, "Asin": { @@ -62,24 +71,54 @@ var operators = map[string]OperatorVersions{ 9: atanh.NewAtanh9, }, "Cast": { - 6: cast.NewCast6, - 9: cast.NewCast9, + 6: cast.NewCast6, // Same, but string type is added + 9: cast.NewCast9, // Same, but bfloat16 type 
differs 13: cast.NewCast13, }, "Concat": { 4: concat.NewConcat4, - 11: concat.NewConcat11, + 11: concat.NewConcat11, // Same, but bfloat16 type differs 13: concat.NewConcat13, }, - "Constant": {}, - "ConstantOfShape": {}, - "Conv": {}, - "Cos": {}, - "Cosh": {}, - "Div": {}, - "Equal": {}, - "Expand": {}, - "Flatten": {}, + "Constant": { + 1: constant.NewConstant1, + 9: constant.NewConstant9, + 11: constant.NewConstant11, + 12: constant.NewConstant12, // Same, but bfloat16 type differs + 13: constant.NewConstant13, + }, + "ConstantOfShape": { + 9: constantofshape.NewConstantOfShape9, + }, + "Conv": { + 1: conv.NewConv1, // Same, but only float16 type differs + 11: conv.NewConv11, + }, + "Cos": { + 7: cos.NewCos7, + }, + "Cosh": { + 9: cosh.NewCosh9, + }, + "Div": { + 7: div.NewDiv7, // Same, but float16 type differs + 13: div.NewDiv13, + }, + "Equal": { + 7: equal.NewEqual7, + 11: equal.NewEqual11, // Same, but float16 type differs + 13: equal.NewEqual13, + }, + "Expand": { + 8: expand.NewExpand8, // Same, but float16 type differs + 13: expand.NewExpand13, + }, + "Flatten": { + 1: flatten.NewFlatten1, // Same, but only float types + 9: flatten.NewFlatten9, // Same, but negative axis added + 11: flatten.NewFlatten11, // Same, but float16 type differs + 13: flatten.NewFlatten13, + }, "Gather": {}, "Gemm": {}, "Greater": {},