diff --git a/ops/abs/abs_13.go b/ops/abs/abs_13.go index 69a768b..0ad2ad5 100644 --- a/ops/abs/abs_13.go +++ b/ops/abs/abs_13.go @@ -15,7 +15,7 @@ const ( type Abs13 struct{} // newAbs13 creates a new abs operator. -func NewAbs13() ops.Operator { +func newAbs13() ops.Operator { return &Abs13{} } diff --git a/ops/abs/abs_6.go b/ops/abs/abs_6.go index 704192b..e4c4f6e 100644 --- a/ops/abs/abs_6.go +++ b/ops/abs/abs_6.go @@ -15,7 +15,7 @@ const ( type Abs6 struct{} // newAbs6 creates a new abs operator. -func NewAbs6() ops.Operator { +func newAbs6() ops.Operator { return &Abs6{} } diff --git a/ops/abs/versions.go b/ops/abs/versions.go new file mode 100644 index 0000000..975523e --- /dev/null +++ b/ops/abs/versions.go @@ -0,0 +1,10 @@ +package abs + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AbsVersions = ops.OperatorVersions{ + 6: newAbs6, // Same, but bfloat16 type is added + 13: newAbs13, +} diff --git a/ops/acos/acos_7.go b/ops/acos/acos_7.go index 9ced81e..6e80244 100644 --- a/ops/acos/acos_7.go +++ b/ops/acos/acos_7.go @@ -12,7 +12,7 @@ import ( type Acos7 struct{} // newAcos7 creates a new acos operator. -func NewAcos7() ops.Operator { +func newAcos7() ops.Operator { return &Acos7{} } diff --git a/ops/acos/versions.go b/ops/acos/versions.go new file mode 100644 index 0000000..7ea32ad --- /dev/null +++ b/ops/acos/versions.go @@ -0,0 +1,9 @@ +package acos + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AcosVersions = ops.OperatorVersions{ + 7: newAcos7, +} diff --git a/ops/acosh/acosh_9.go b/ops/acosh/acosh_9.go index d37f7fb..f376795 100644 --- a/ops/acosh/acosh_9.go +++ b/ops/acosh/acosh_9.go @@ -12,7 +12,7 @@ import ( type Acosh9 struct{} // newAcosh9 creates a new acosh operator. -func NewAcosh9() ops.Operator { +func newAcosh9() ops.Operator { return &Acosh9{} } diff --git a/ops/acosh/versions.go b/ops/acosh/versions.go new file mode 100644 index 0000000..3fff494 --- /dev/null +++ b/ops/acosh/versions.go @@ -0,0 +1,9 @@ +package acosh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AcoshVersions = ops.OperatorVersions{ + 9: newAcosh9, +} diff --git a/ops/add/add_13.go b/ops/add/add_13.go index 8fd6a52..a366e5d 100644 --- a/ops/add/add_13.go +++ b/ops/add/add_13.go @@ -15,7 +15,7 @@ const ( type Add13 struct{} // newAdd13 creates a new add operator. -func NewAdd13() ops.Operator { +func newAdd13() ops.Operator { return &Add13{} } diff --git a/ops/add/add_7.go b/ops/add/add_7.go index a5bc9ca..f83cd37 100644 --- a/ops/add/add_7.go +++ b/ops/add/add_7.go @@ -15,7 +15,7 @@ const ( type Add7 struct{} // newAdd7 creates a new add operator. -func NewAdd7() ops.Operator { +func newAdd7() ops.Operator { return &Add7{} } diff --git a/ops/add/versions.go b/ops/add/versions.go new file mode 100644 index 0000000..4af7e71 --- /dev/null +++ b/ops/add/versions.go @@ -0,0 +1,10 @@ +package add + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AddVersions = ops.OperatorVersions{ + 7: newAdd7, // Same, but bfloat16 type is added + 13: newAdd13, +} diff --git a/ops/and/and_7.go b/ops/and/and_7.go index 44d6707..2f768bc 100644 --- a/ops/and/and_7.go +++ b/ops/and/and_7.go @@ -15,7 +15,7 @@ var ( type And7 struct{} // newAnd7 creates a new and operator. -func NewAnd7() ops.Operator { +func newAnd7() ops.Operator { return &And7{} } diff --git a/ops/and/versions.go b/ops/and/versions.go new file mode 100644 index 0000000..1b26c23 --- /dev/null +++ b/ops/and/versions.go @@ -0,0 +1,9 @@ +package and + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AndVersions = ops.OperatorVersions{ + 7: newAnd7, +} diff --git a/ops/argmax/argmax_11.go b/ops/argmax/argmax_11.go index 21f23e8..ce8463a 100644 --- a/ops/argmax/argmax_11.go +++ b/ops/argmax/argmax_11.go @@ -19,7 +19,7 @@ type ArgMax11 struct { } // newArgMax11 creates a new argmax operator. -func NewArgMax11() ops.Operator { +func newArgMax11() ops.Operator { return &ArgMax11{ keepDims: true, selectLastIndex: false, diff --git a/ops/argmax/argmax_12.go b/ops/argmax/argmax_12.go index 9d0f9a7..da552d7 100644 --- a/ops/argmax/argmax_12.go +++ b/ops/argmax/argmax_12.go @@ -19,7 +19,7 @@ type ArgMax12 struct { } // newArgMax12 creates a new argmax operator. -func NewArgMax12() ops.Operator { +func newArgMax12() ops.Operator { return &ArgMax12{ keepDims: true, selectLastIndex: false, diff --git a/ops/argmax/argmax_13.go b/ops/argmax/argmax_13.go index b32d1e1..d8ae77f 100644 --- a/ops/argmax/argmax_13.go +++ b/ops/argmax/argmax_13.go @@ -19,7 +19,7 @@ type ArgMax13 struct { } // newArgMax13 creates a new argmax operator. -func NewArgMax13() ops.Operator { +func newArgMax13() ops.Operator { return &ArgMax13{ keepDims: true, selectLastIndex: false, diff --git a/ops/argmax/versions.go b/ops/argmax/versions.go new file mode 100644 index 0000000..e9c44aa --- /dev/null +++ b/ops/argmax/versions.go @@ -0,0 +1,11 @@ +package argmax + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ArgMaxVersions = ops.OperatorVersions{ + 11: newArgMax11, // Same, but one attribute is added (which we don't support it anyway) + 12: newArgMax12, // Same, but bfloat16 type differs + 13: newArgMax13, +} diff --git a/ops/asin/asin_7.go b/ops/asin/asin_7.go index 35848a1..c6466fb 100644 --- a/ops/asin/asin_7.go +++ b/ops/asin/asin_7.go @@ -12,7 +12,7 @@ import ( type Asin7 struct{} // newSin creates a new asin operator. -func NewAsin7() ops.Operator { +func newAsin7() ops.Operator { return &Asin7{} } diff --git a/ops/asin/versions.go b/ops/asin/versions.go new file mode 100644 index 0000000..415dab1 --- /dev/null +++ b/ops/asin/versions.go @@ -0,0 +1,9 @@ +package asin + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AsinVersions = ops.OperatorVersions{ + 7: newAsin7, +} diff --git a/ops/asinh/asinh_9.go b/ops/asinh/asinh_9.go index ec0eab8..eab0151 100644 --- a/ops/asinh/asinh_9.go +++ b/ops/asinh/asinh_9.go @@ -12,7 +12,7 @@ import ( type Asinh9 struct{} // newAsinh9 creates a new asinh operator. -func NewAsinh9() ops.Operator { +func newAsinh9() ops.Operator { return &Asinh9{} } diff --git a/ops/asinh/versions.go b/ops/asinh/versions.go new file mode 100644 index 0000000..ffeabb2 --- /dev/null +++ b/ops/asinh/versions.go @@ -0,0 +1,9 @@ +package asinh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AsinhVersions = ops.OperatorVersions{ + 9: newAsinh9, +} diff --git a/ops/atan/atan_7.go b/ops/atan/atan_7.go index aeb093a..6e29017 100644 --- a/ops/atan/atan_7.go +++ b/ops/atan/atan_7.go @@ -12,7 +12,7 @@ import ( type Atan7 struct{} // newAtan7 creates a new atan operator. -func NewAtan7() ops.Operator { +func newAtan7() ops.Operator { return &Atan7{} } diff --git a/ops/atan/versions.go b/ops/atan/versions.go new file mode 100644 index 0000000..f1ac3c5 --- /dev/null +++ b/ops/atan/versions.go @@ -0,0 +1,9 @@ +package atan + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AtanVersions = ops.OperatorVersions{ + 7: newAtan7, +} diff --git a/ops/atanh/atanh_9.go b/ops/atanh/atanh_9.go index 5f1f484..28e6b88 100644 --- a/ops/atanh/atanh_9.go +++ b/ops/atanh/atanh_9.go @@ -12,7 +12,7 @@ import ( type Atanh9 struct{} // newAtanh9 creates a new atanh operator. -func NewAtanh9() ops.Operator { +func newAtanh9() ops.Operator { return &Atanh9{} } diff --git a/ops/atanh/versions.go b/ops/atanh/versions.go new file mode 100644 index 0000000..eb0be68 --- /dev/null +++ b/ops/atanh/versions.go @@ -0,0 +1,9 @@ +package atanh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var AtanhVersions = ops.OperatorVersions{ + 9: newAtanh9, +} diff --git a/ops/cast/cast_13.go b/ops/cast/cast_13.go index f45aaf1..f3bfb41 100644 --- a/ops/cast/cast_13.go +++ b/ops/cast/cast_13.go @@ -17,7 +17,7 @@ type Cast13 struct { } // newCast13 creates a new cast operator. -func NewCast13() ops.Operator { +func newCast13() ops.Operator { return &Cast13{} } diff --git a/ops/cast/cast_6.go b/ops/cast/cast_6.go index 07b9254..71628f1 100644 --- a/ops/cast/cast_6.go +++ b/ops/cast/cast_6.go @@ -17,7 +17,7 @@ type Cast6 struct { } // newCast6 creates a new cast operator. -func NewCast6() ops.Operator { +func newCast6() ops.Operator { return &Cast6{} } diff --git a/ops/cast/cast_9.go b/ops/cast/cast_9.go index 9805ec1..2751d84 100644 --- a/ops/cast/cast_9.go +++ b/ops/cast/cast_9.go @@ -17,7 +17,7 @@ type Cast9 struct { } // newCast9 creates a new cast operator. -func NewCast9() ops.Operator { +func newCast9() ops.Operator { return &Cast9{} } diff --git a/ops/cast/versions.go b/ops/cast/versions.go new file mode 100644 index 0000000..d47a24e --- /dev/null +++ b/ops/cast/versions.go @@ -0,0 +1,11 @@ +package cast + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var CastVersions = ops.OperatorVersions{ + 6: newCast6, // Same, but string type is added + 9: newCast9, // Same, but bfloat16 type differs + 13: newCast13, +} diff --git a/ops/concat/concat_11.go b/ops/concat/concat_11.go index 28840f8..98188e0 100644 --- a/ops/concat/concat_11.go +++ b/ops/concat/concat_11.go @@ -18,7 +18,7 @@ type Concat11 struct { } // newConcat11 creates a new concat operator. -func NewConcat11() ops.Operator { +func newConcat11() ops.Operator { return &Concat11{} } diff --git a/ops/concat/concat_13.go b/ops/concat/concat_13.go index 088d244..10f84c4 100644 --- a/ops/concat/concat_13.go +++ b/ops/concat/concat_13.go @@ -18,7 +18,7 @@ type Concat13 struct { } // newConcat13 creates a new concat operator. -func NewConcat13() ops.Operator { +func newConcat13() ops.Operator { return &Concat13{} } diff --git a/ops/concat/concat_4.go b/ops/concat/concat_4.go index 538ddf1..7b47410 100644 --- a/ops/concat/concat_4.go +++ b/ops/concat/concat_4.go @@ -18,7 +18,7 @@ type Concat4 struct { } // newConcat4 creates a new concat operator. -func NewConcat4() ops.Operator { +func newConcat4() ops.Operator { return &Concat4{} } diff --git a/ops/concat/versions.go b/ops/concat/versions.go new file mode 100644 index 0000000..8ad09ba --- /dev/null +++ b/ops/concat/versions.go @@ -0,0 +1,11 @@ +package concat + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ConcatVersions = ops.OperatorVersions{ + 4: newConcat4, + 11: newConcat11, // Same, but bfloat16 type differs + 13: newConcat13, +} diff --git a/ops/constant/constant_1.go b/ops/constant/constant_1.go index e8f4cea..9fd18c0 100644 --- a/ops/constant/constant_1.go +++ b/ops/constant/constant_1.go @@ -12,7 +12,7 @@ type Constant1 struct { } // newConstant1 creates a new constant operator. -func NewConstant1() ops.Operator { +func newConstant1() ops.Operator { return &Constant1{} } diff --git a/ops/constant/constant_11.go b/ops/constant/constant_11.go index 557df5e..b10b7ea 100644 --- a/ops/constant/constant_11.go +++ b/ops/constant/constant_11.go @@ -12,7 +12,7 @@ type Constant11 struct { } // newConstant11 creates a new constant operator. -func NewConstant11() ops.Operator { +func newConstant11() ops.Operator { return &Constant11{} } diff --git a/ops/constant/constant_12.go b/ops/constant/constant_12.go index 442a4f4..0174bea 100644 --- a/ops/constant/constant_12.go +++ b/ops/constant/constant_12.go @@ -12,7 +12,7 @@ type Constant12 struct { } // newConstant12 creates a new constant operator. -func NewConstant12() ops.Operator { +func newConstant12() ops.Operator { return &Constant12{} } diff --git a/ops/constant/constant_13.go b/ops/constant/constant_13.go index 994d2ee..85261bc 100644 --- a/ops/constant/constant_13.go +++ b/ops/constant/constant_13.go @@ -12,7 +12,7 @@ type Constant13 struct { } // newConstant13 creates a new constant operator. -func NewConstant13() ops.Operator { +func newConstant13() ops.Operator { return &Constant13{} } diff --git a/ops/constant/constant_9.go b/ops/constant/constant_9.go index 43be99f..08b6830 100644 --- a/ops/constant/constant_9.go +++ b/ops/constant/constant_9.go @@ -12,7 +12,7 @@ type Constant9 struct { } // newConstant9 creates a new constant operator. -func NewConstant9() ops.Operator { +func newConstant9() ops.Operator { return &Constant9{} } diff --git a/ops/constant/versions.go b/ops/constant/versions.go new file mode 100644 index 0000000..607e613 --- /dev/null +++ b/ops/constant/versions.go @@ -0,0 +1,13 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ConstantVersions = ops.OperatorVersions{ + 1: newConstant1, + 9: newConstant9, + 11: newConstant11, + 12: newConstant12, // Same, but bfloat16 type differs + 13: newConstant13, +} diff --git a/ops/constantofshape/constant_of_shape_9.go b/ops/constantofshape/constant_of_shape_9.go index 6aee059..ac7714d 100644 --- a/ops/constantofshape/constant_of_shape_9.go +++ b/ops/constantofshape/constant_of_shape_9.go @@ -19,7 +19,7 @@ type ConstantOfShape9 struct { } // newConstantOfShape9 creates a new constant of shape operator. -func NewConstantOfShape9() ops.Operator { +func newConstantOfShape9() ops.Operator { return &ConstantOfShape9{} } diff --git a/ops/constantofshape/versions.go b/ops/constantofshape/versions.go new file mode 100644 index 0000000..6900ac8 --- /dev/null +++ b/ops/constantofshape/versions.go @@ -0,0 +1,9 @@ +package constantofshape + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ConstantOfShapeVersions = ops.OperatorVersions{ + 9: newConstantOfShape9, +} diff --git a/ops/conv/conv_1.go b/ops/conv/conv_1.go index fe7bef4..5f5c8e6 100644 --- a/ops/conv/conv_1.go +++ b/ops/conv/conv_1.go @@ -24,7 +24,7 @@ type Conv1 struct { } // newConv1 creates a new conv operator. -func NewConv1() ops.Operator { +func newConv1() ops.Operator { return &Conv1{ autoPad: NotSet, } diff --git a/ops/conv/conv_11.go b/ops/conv/conv_11.go index bf740fc..f0a00c2 100644 --- a/ops/conv/conv_11.go +++ b/ops/conv/conv_11.go @@ -39,7 +39,7 @@ type Conv11 struct { } // newConv11 creates a new conv operator. -func NewConv11() ops.Operator { +func newConv11() ops.Operator { return &Conv11{ autoPad: NotSet, } diff --git a/ops/conv/versions.go b/ops/conv/versions.go new file mode 100644 index 0000000..a3cf8d4 --- /dev/null +++ b/ops/conv/versions.go @@ -0,0 +1,10 @@ +package conv + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ConvVersions = ops.OperatorVersions{ + 1: newConv1, // Same, but only float16 type differs + 11: newConv11, +} diff --git a/ops/cos/cos_7.go b/ops/cos/cos_7.go index e086ad7..634f9b3 100644 --- a/ops/cos/cos_7.go +++ b/ops/cos/cos_7.go @@ -12,7 +12,7 @@ import ( type Cos7 struct{} // newCos7 creates a new cos operator. -func NewCos7() ops.Operator { +func newCos7() ops.Operator { return &Cos7{} } diff --git a/ops/cos/versions.go b/ops/cos/versions.go new file mode 100644 index 0000000..c8a19da --- /dev/null +++ b/ops/cos/versions.go @@ -0,0 +1,9 @@ +package cos + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var CosVersions = ops.OperatorVersions{ + 7: newCos7, +} diff --git a/ops/cosh/cosh_9.go b/ops/cosh/cosh_9.go index 0e8e586..7c8da59 100644 --- a/ops/cosh/cosh_9.go +++ b/ops/cosh/cosh_9.go @@ -12,7 +12,7 @@ import ( type Cosh9 struct{} // newCosh9 creates a new cosh operator. -func NewCosh9() ops.Operator { +func newCosh9() ops.Operator { return &Cosh9{} } diff --git a/ops/cosh/versions.go b/ops/cosh/versions.go new file mode 100644 index 0000000..c8cd83a --- /dev/null +++ b/ops/cosh/versions.go @@ -0,0 +1,9 @@ +package cosh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var CoshVersions = ops.OperatorVersions{ + 9: newCosh9, +} diff --git a/ops/div/div_13.go b/ops/div/div_13.go index 882fd24..2e7640a 100644 --- a/ops/div/div_13.go +++ b/ops/div/div_13.go @@ -15,7 +15,7 @@ const ( type Div13 struct{} // newDiv13 creates a new div operator. -func NewDiv13() ops.Operator { +func newDiv13() ops.Operator { return &Div13{} } diff --git a/ops/div/div_7.go b/ops/div/div_7.go index a8cbdff..fdf745b 100644 --- a/ops/div/div_7.go +++ b/ops/div/div_7.go @@ -15,7 +15,7 @@ const ( type Div7 struct{} // newDiv7 creates a new div operator. -func NewDiv7() ops.Operator { +func newDiv7() ops.Operator { return &Div7{} } diff --git a/ops/div/versions.go b/ops/div/versions.go new file mode 100644 index 0000000..5e22173 --- /dev/null +++ b/ops/div/versions.go @@ -0,0 +1,10 @@ +package div + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var DivVersions = ops.OperatorVersions{ + 7: newDiv7, // Same, but float16 type differs + 13: newDiv13, +} diff --git a/ops/equal/equal_11.go b/ops/equal/equal_11.go index bbe468a..d53590b 100644 --- a/ops/equal/equal_11.go +++ b/ops/equal/equal_11.go @@ -15,7 +15,7 @@ var ( type Equal11 struct{} // newEqual11 creates a new equal operator. -func NewEqual11() ops.Operator { +func newEqual11() ops.Operator { return &Equal11{} } diff --git a/ops/equal/equal_13.go b/ops/equal/equal_13.go index 55a52cd..7e3b0a9 100644 --- a/ops/equal/equal_13.go +++ b/ops/equal/equal_13.go @@ -15,7 +15,7 @@ var ( type Equal13 struct{} // newEqual13 creates a new equal operator. -func NewEqual13() ops.Operator { +func newEqual13() ops.Operator { return &Equal13{} } diff --git a/ops/equal/equal_7.go b/ops/equal/equal_7.go index 79583b2..cbdc1f0 100644 --- a/ops/equal/equal_7.go +++ b/ops/equal/equal_7.go @@ -15,7 +15,7 @@ var ( type Equal7 struct{} // newEqual7 creates a new equal operator. -func NewEqual7() ops.Operator { +func newEqual7() ops.Operator { return &Equal7{} } diff --git a/ops/equal/versions.go b/ops/equal/versions.go new file mode 100644 index 0000000..26d99bc --- /dev/null +++ b/ops/equal/versions.go @@ -0,0 +1,11 @@ +package equal + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var EqualVersions = ops.OperatorVersions{ + 7: newEqual7, + 11: newEqual11, // Same, but float16 type differs + 13: newEqual13, +} diff --git a/ops/expand/expand_13.go b/ops/expand/expand_13.go index ebedb1e..18051d4 100644 --- a/ops/expand/expand_13.go +++ b/ops/expand/expand_13.go @@ -15,7 +15,7 @@ const ( type Expand13 struct{} // newExpand13 creates a new expand operator. -func NewExpand13() ops.Operator { +func newExpand13() ops.Operator { return &Expand13{} } diff --git a/ops/expand/expand_8.go b/ops/expand/expand_8.go index cf399ab..6bee3dd 100644 --- a/ops/expand/expand_8.go +++ b/ops/expand/expand_8.go @@ -15,7 +15,7 @@ const ( type Expand8 struct{} // newExpand8 creates a new expand operator. -func NewExpand8() ops.Operator { +func newExpand8() ops.Operator { return &Expand8{} } diff --git a/ops/expand/versions.go b/ops/expand/versions.go new file mode 100644 index 0000000..5f05988 --- /dev/null +++ b/ops/expand/versions.go @@ -0,0 +1,10 @@ +package expand + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var ExpandVersions = ops.OperatorVersions{ + 8: newExpand8, // Same, but float16 type differs + 13: newExpand13, +} diff --git a/ops/flatten/flatten_1.go b/ops/flatten/flatten_1.go index dda0910..42f20fa 100644 --- a/ops/flatten/flatten_1.go +++ b/ops/flatten/flatten_1.go @@ -17,7 +17,7 @@ type Flatten1 struct { } // newFlatten1 creates a new flatten operator. -func NewFlatten1() ops.Operator { +func newFlatten1() ops.Operator { return &Flatten1{ axis: 1, } diff --git a/ops/flatten/flatten_11.go b/ops/flatten/flatten_11.go index a2cc90b..55685e4 100644 --- a/ops/flatten/flatten_11.go +++ b/ops/flatten/flatten_11.go @@ -17,7 +17,7 @@ type Flatten11 struct { } // newFlatten11 creates a new flatten operator. -func NewFlatten11() ops.Operator { +func newFlatten11() ops.Operator { return &Flatten11{ axis: 1, } diff --git a/ops/flatten/flatten_13.go b/ops/flatten/flatten_13.go index fba34b9..621f74c 100644 --- a/ops/flatten/flatten_13.go +++ b/ops/flatten/flatten_13.go @@ -17,7 +17,7 @@ type Flatten13 struct { } // newFlatten13 creates a new flatten operator. -func NewFlatten13() ops.Operator { +func newFlatten13() ops.Operator { return &Flatten13{ axis: 1, } diff --git a/ops/flatten/flatten_9.go b/ops/flatten/flatten_9.go index 1e5ef04..b7f5c45 100644 --- a/ops/flatten/flatten_9.go +++ b/ops/flatten/flatten_9.go @@ -17,7 +17,7 @@ type Flatten9 struct { } // newFlatten9 creates a new flatten operator. -func NewFlatten9() ops.Operator { +func newFlatten9() ops.Operator { return &Flatten9{ axis: 1, } diff --git a/ops/flatten/versions.go b/ops/flatten/versions.go new file mode 100644 index 0000000..f158a08 --- /dev/null +++ b/ops/flatten/versions.go @@ -0,0 +1,12 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var FlattenVersions = ops.OperatorVersions{ + 1: newFlatten1, // Same, but only float types + 9: newFlatten9, // Same, but negative axis added + 11: newFlatten11, // Same, but float16 type differs + 13: newFlatten13, +} diff --git a/ops/gather/gather_1.go b/ops/gather/gather_1.go new file mode 100644 index 0000000..01cdce4 --- /dev/null +++ b/ops/gather/gather_1.go @@ -0,0 +1,121 @@ +package gather + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinGather1Inputs = 2 + MaxGather1Inputs = 2 +) + +// Gather1 represents the ONNX gather operator. +type Gather1 struct { + axis int // axis to gather on, default is 0 +} + +// newGather1 creates a new gather operator. +func newGather1() ops.Operator { + return &Gather1{ + axis: 0, + } +} + +// Init initializes the gather operator. +func (g *Gather1) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + + if len(attributes) == 1 { + attr := attributes[0] + + if attr.GetName() == "axis" { + g.axis = int(attr.GetI()) + } else { + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } else if len(attributes) > 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), g) + } + + return nil +} + +// Apply applies the gather operator. +func (g *Gather1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + // Convert the indices (of Dtype Int32 or Int64) to a tensor with Dtype Int + indicesData, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data())) + if err != nil { + return nil, err + } + + indices := tensor.New(tensor.WithBacking(indicesData), tensor.WithShape(inputs[1].Shape()...)) + + data := inputs[0] + + // Make sure axis is in the correct range (according to the size of the data tensor) + rank := len(data.Shape()) + dataAxis := g.axis + + if dataAxis < -rank || dataAxis > rank-1 { + return nil, ops.ErrAxisOutOfRange(rank, rank, dataAxis) + } + // Offset axis if a negative index is given. + if dataAxis < 0 { + dataAxis += rank + } + + // Make sure the input indices are all in the correct range (according to the size of the + // dimension which is selected by `axis`) + axisDimSize := data.Shape()[dataAxis] + if !ops.AllInRange(indicesData, -axisDimSize, axisDimSize-1) { + return nil, ops.ErrNotAllAxesInRange(axisDimSize, axisDimSize) + } + + err = ops.OffsetTensorIfNegative(indices, axisDimSize) + if err != nil { + return nil, err + } + + // Make the shape of the output tensor + os := insertWithReplace(indices.Shape(), data.Shape(), dataAxis) + output := tensor.New(tensor.WithShape(os...), tensor.Of(data.Dtype())) + + // Perform the actual gather operation + err = gather(output, data, indices, dataAxis) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Gather1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Gather1) GetMinInputs() int { + return MinGather1Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Gather1) GetMaxInputs() int { + return MaxGather1Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Gather1) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int32, tensor.Int64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Gather1) String() string { + return "gather1 operator" +} diff --git a/ops/gather/gather_11.go b/ops/gather/gather_11.go new file mode 100644 index 0000000..3677d39 --- /dev/null +++ b/ops/gather/gather_11.go @@ -0,0 +1,121 @@ +package gather + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinGather11Inputs = 2 + MaxGather11Inputs = 2 +) + +// Gather11 represents the ONNX gather operator. +type Gather11 struct { + axis int // axis to gather on, default is 0 +} + +// newGather11 creates a new gather operator. +func newGather11() ops.Operator { + return &Gather11{ + axis: 0, + } +} + +// Init initializes the gather operator. +func (g *Gather11) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + + if len(attributes) == 1 { + attr := attributes[0] + + if attr.GetName() == "axis" { + g.axis = int(attr.GetI()) + } else { + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } else if len(attributes) > 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), g) + } + + return nil +} + +// Apply applies the gather operator. +func (g *Gather11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + // Convert the indices (of Dtype Int32 or Int64) to a tensor with Dtype Int + indicesData, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data())) + if err != nil { + return nil, err + } + + indices := tensor.New(tensor.WithBacking(indicesData), tensor.WithShape(inputs[1].Shape()...)) + + data := inputs[0] + + // Make sure axis is in the correct range (according to the size of the data tensor) + rank := len(data.Shape()) + dataAxis := g.axis + + if dataAxis < -rank || dataAxis > rank-1 { + return nil, ops.ErrAxisOutOfRange(rank, rank, dataAxis) + } + // Offset axis if a negative index is given. + if dataAxis < 0 { + dataAxis += rank + } + + // Make sure the input indices are all in the correct range (according to the size of the + // dimension which is selected by `axis`) + axisDimSize := data.Shape()[dataAxis] + if !ops.AllInRange(indicesData, -axisDimSize, axisDimSize-1) { + return nil, ops.ErrNotAllAxesInRange(axisDimSize, axisDimSize) + } + + err = ops.OffsetTensorIfNegative(indices, axisDimSize) + if err != nil { + return nil, err + } + + // Make the shape of the output tensor + os := insertWithReplace(indices.Shape(), data.Shape(), dataAxis) + output := tensor.New(tensor.WithShape(os...), tensor.Of(data.Dtype())) + + // Perform the actual gather operation + err = gather(output, data, indices, dataAxis) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Gather11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Gather11) GetMinInputs() int { + return MinGather11Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Gather11) GetMaxInputs() int { + return MaxGather11Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Gather11) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int32, tensor.Int64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Gather11) String() string { + return "gather11 operator" +} diff --git a/ops/opset13/gather.go b/ops/gather/gather_13.go similarity index 89% rename from ops/opset13/gather.go rename to ops/gather/gather_13.go index e6e7f3f..f4bd7fc 100644 --- a/ops/opset13/gather.go +++ b/ops/gather/gather_13.go @@ -1,4 +1,4 @@ -package opset13 +package gather import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,24 +7,24 @@ import ( ) const ( - MinGatherInputs = 2 - MaxGatherInputs = 2 + MinGather13Inputs = 2 + MaxGather13Inputs = 2 ) -// Gather represents the ONNX gather operator. -type Gather struct { +// Gather13 represents the ONNX gather operator. +type Gather13 struct { axis int // axis to gather on, default is 0 } -// newGather creates a new gather operator. -func newGather() ops.Operator { - return &Gather{ +// newGather13 creates a new gather operator. +func newGather13() ops.Operator { + return &Gather13{ axis: 0, } } // Init initializes the gather operator. -func (g *Gather) Init(n *onnx.NodeProto) error { +func (g *Gather13) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() if len(attributes) == 1 { @@ -43,7 +43,7 @@ func (g *Gather) Init(n *onnx.NodeProto) error { } // Apply applies the gather operator. -func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Gather13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // Convert the indices (of Dtype Int32 or Int64) to a tensor with Dtype Int indicesData, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data())) if err != nil { @@ -92,23 +92,23 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Gather13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(g, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Gather) GetMinInputs() int { - return MinGatherInputs +func (g *Gather13) GetMinInputs() int { + return MinGather13Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Gather) GetMaxInputs() int { - return MaxGatherInputs +func (g *Gather13) GetMaxInputs() int { + return MaxGather13Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (g *Gather) GetInputTypeConstraints() [][]tensor.Dtype { +func (g *Gather13) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ ops.AllTypes, {tensor.Int32, tensor.Int64}, @@ -116,8 +116,8 @@ func (g *Gather) GetInputTypeConstraints() [][]tensor.Dtype { } // String implements the stringer interface, and can be used to format errors or messages. -func (g *Gather) String() string { - return "gather operator" +func (g *Gather13) String() string { + return "gather13 operator" } // Perform gather according to the definition given by ONNX : diff --git a/ops/opset13/gather_test.go b/ops/gather/gather_13_test.go similarity index 86% rename from ops/opset13/gather_test.go rename to ops/gather/gather_13_test.go index e48925a..d9df16c 100644 --- a/ops/opset13/gather_test.go +++ b/ops/gather/gather_13_test.go @@ -1,10 +1,11 @@ -package opset13 +package gather import ( "testing" "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/concat" "github.com/stretchr/testify/assert" "gorgonia.org/tensor" ) @@ -15,34 +16,34 @@ func makeAxisProto(n int) *onnx.NodeProto { } } -func TestGatherInit(t *testing.T) { +func TestGather13Init(t *testing.T) { attrs := makeAxisProto(1) - op := Gather{} + op := Gather13{} err := op.Init(attrs) assert.NoError(t, err) assert.Equal(t, op.axis, 1) } -func TestGatherInitDefault(t *testing.T) { - op := Gather{} +func TestGather13InitDefault(t *testing.T) { + op := Gather13{} err := op.Init(ops.EmptyNodeProto()) assert.NoError(t, err) assert.Equal(t, op.axis, 0) } -func TestGatherInitTooManyAttrs(t *testing.T) { - op := Gather{} +func TestGather13InitTooManyAttrs(t *testing.T) { + op := Gather13{} err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}}) - assert.EqualError(t, err, "gather operator attribute error: invalid count 2 expected 1") + assert.EqualError(t, err, "gather13 operator attribute error: invalid count 2 expected 1") } -func TestGatherInitInvalidAttrName(t *testing.T) { - op := Gather{} +func TestGather13InitInvalidAttrName(t *testing.T) { + op := Gather13{} err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axes"}}}) // should be axis - assert.EqualError(t, err, "gather operator attribute error: invalid attribute axes") + assert.EqualError(t, err, "gather13 operator attribute error: invalid attribute axes") } -func TestGather(t *testing.T) { +func TestGather13(t *testing.T) { tests := []struct { data interface{} shape []int @@ -186,7 +187,7 @@ func TestGather(t *testing.T) { } for _, test := range tests { - op := &Gather{test.axis} + op := &Gather13{test.axis} indices := test.indices data := test.data @@ -202,7 +203,7 @@ func TestGather(t *testing.T) { } func TestCombinedWithOtherOp(t *testing.T) { - concat := &Concat{} + concat := &concat.Concat13{} err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 0}}}) assert.NoError(t, err) @@ -212,7 +213,7 @@ func TestCombinedWithOtherOp(t *testing.T) { data, err := concat.Apply([]tensor.Tensor{data0, data1}) assert.NoError(t, err) - gather := &Gather{0} + gather := &Gather13{0} indices := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) res, err := gather.Apply([]tensor.Tensor{data[0], indices}) @@ -221,7 +222,7 @@ func TestCombinedWithOtherOp(t *testing.T) { } func TestScalarInput(t *testing.T) { - op := &Gather{0} + op := &Gather13{0} dataIn := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) @@ -233,8 +234,8 @@ func TestScalarInput(t *testing.T) { assert.Equal(t, int64(1), res[0].Data()) } -func TestGatherAxesIndexOutOfRange(t *testing.T) { - op := &Gather{} +func TestGather13AxesIndexOutOfRange(t *testing.T) { + op := &Gather13{} err := op.Init(makeAxisProto(1)) assert.NoError(t, err) @@ -246,8 +247,8 @@ func TestGatherAxesIndexOutOfRange(t *testing.T) { assert.EqualError(t, err, "axis out of range: axis argument must be in the range -1 <= x < 1, was 1") } -func TestGatherIndexOutOfRange(t *testing.T) { - op := &Gather{0} +func TestGather13IndexOutOfRange(t *testing.T) { + op := &Gather13{0} dataIn := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) indicesIn := tensor.New(tensor.WithBacking([]int64{2}), tensor.WithShape(1)) @@ -257,7 +258,7 @@ func TestGatherIndexOutOfRange(t *testing.T) { assert.EqualError(t, err, "axis out of range: all indices entries must be in the range -1 <= x < 1") } -func TestInputValidationGather(t *testing.T) { +func TestInputValidationGather13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -278,19 +279,19 @@ func TestInputValidationGather(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputCount(1, &Gather{}), + ops.ErrInvalidInputCount(1, &Gather13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), }, - ops.ErrInvalidInputType(1, "float32", &Gather{}), + ops.ErrInvalidInputType(1, "float32", &Gather13{}), }, } for _, test := range tests { - gather := &Gather{} + gather := &Gather13{} validated, err := gather.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/gather/versions.go b/ops/gather/versions.go new file mode 100644 index 0000000..dda16f1 --- /dev/null +++ b/ops/gather/versions.go @@ -0,0 +1,9 @@ +package gather + +import "github.com/advancedclimatesystems/gonnx/ops" + +var GatherVersions = ops.OperatorVersions{ + 1: newGather1, // Same, but with negative axis + 11: newGather11, // Same, but with bfloat16 + 13: newGather13, +} diff --git a/ops/opset13/gemm.go b/ops/gemm/gemm_11.go similarity index 74% rename from ops/opset13/gemm.go rename to ops/gemm/gemm_11.go index 2db2a44..44c7638 100644 --- a/ops/opset13/gemm.go +++ b/ops/gemm/gemm_11.go @@ -1,4 +1,4 @@ -package opset13 +package gemm import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,21 +7,21 @@ import ( ) const ( - MinGemmInputs = 2 - MaxGemmInputs = 3 + MinGemm11Inputs = 2 + MaxGemm11Inputs = 3 ) -// Gemm represents the ONNX gemm operator. -type Gemm struct { +// Gemm11 represents the ONNX gemm operator. +type Gemm11 struct { alpha float32 beta float32 transA bool transB bool } -// newGemm creates a new gemm operator and initializes it with the default values. -func newGemm() ops.Operator { - return &Gemm{ +// newGemm11 creates a new gemm operator and initializes it with the default values. +func newGemm11() ops.Operator { + return &Gemm11{ alpha: 1.0, beta: 1.0, transA: false, @@ -29,8 +29,8 @@ func newGemm() ops.Operator { } } -// Init initializes the Gemm operator based on the ModelProto attributes. -func (g *Gemm) Init(n *onnx.NodeProto) error { +// Init initializes the Gemm11 operator based on the ModelProto attributes. +func (g *Gemm11) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { case "alpha": @@ -50,7 +50,7 @@ func (g *Gemm) Init(n *onnx.NodeProto) error { } // Apply applies the gemm operator on the given graph. -func (g *Gemm) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Gemm11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var err error a := inputs[0] @@ -105,23 +105,23 @@ func (g *Gemm) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Gemm) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Gemm11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(g, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Gemm) GetMinInputs() int { - return MinGemmInputs +func (g *Gemm11) GetMinInputs() int { + return MinGemm11Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Gemm) GetMaxInputs() int { - return MaxGemmInputs +func (g *Gemm11) GetMaxInputs() int { + return MaxGemm11Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (g *Gemm) GetInputTypeConstraints() [][]tensor.Dtype { +func (g *Gemm11) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, @@ -130,6 +130,6 @@ func (g *Gemm) GetInputTypeConstraints() [][]tensor.Dtype { } // String implements the stringer interface, and can be used to format errors or messages. -func (g *Gemm) String() string { - return "gemm operator" +func (g *Gemm11) String() string { + return "gemm11 operator" } diff --git a/ops/gemm/gemm_13.go b/ops/gemm/gemm_13.go new file mode 100644 index 0000000..b3c6de9 --- /dev/null +++ b/ops/gemm/gemm_13.go @@ -0,0 +1,135 @@ +package gemm + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinGemm13Inputs = 2 + MaxGemm13Inputs = 3 +) + +// Gemm13 represents the ONNX gemm operator. +type Gemm13 struct { + alpha float32 + beta float32 + transA bool + transB bool +} + +// newGemm13 creates a new gemm operator and initializes it with the default values. +func newGemm13() ops.Operator { + return &Gemm13{ + alpha: 1.0, + beta: 1.0, + transA: false, + transB: false, + } +} + +// Init initializes the Gemm13 operator based on the ModelProto attributes. +func (g *Gemm13) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "alpha": + g.alpha = attr.GetF() + case "beta": + g.beta = attr.GetF() + case "transA": + g.transA = ops.Int64ToBool(attr.GetI()) + case "transB": + g.transB = ops.Int64ToBool(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } + + return nil +} + +// Apply applies the gemm operator on the given graph. +func (g *Gemm13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + a := inputs[0] + b := inputs[1] + c := inputs[2] + + if g.transA { + a, err = tensor.Transpose(a) + if err != nil { + return nil, err + } + } + + if g.transB { + b, err = tensor.Transpose(b) + if err != nil { + return nil, err + } + } + + x, err := tensor.MatMul(a, b) + if err != nil { + return nil, err + } + + x, err = tensor.Mul(x, g.alpha) + if err != nil { + return nil, err + } + + // If C was not given, it is assumed to be 0, hence we can stop the calculation here. + if c == nil { + return []tensor.Tensor{x}, nil + } + + y, err := tensor.Mul(c, g.beta) + if err != nil { + return nil, err + } + + x, y, err = ops.UnidirectionalBroadcast(x, y) + if err != nil { + return nil, err + } + + output, err := tensor.Add(x, y) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Gemm13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Gemm13) GetMinInputs() int { + return MinGemm13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Gemm13) GetMaxInputs() int { + return MaxGemm13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Gemm13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Gemm13) String() string { + return "gemm13 operator" +} diff --git a/ops/opset13/gemm_test.go b/ops/gemm/gemm_13_test.go similarity index 84% rename from ops/opset13/gemm_test.go rename to ops/gemm/gemm_13_test.go index 37255d4..ae70808 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/gemm/gemm_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package gemm import ( "testing" @@ -9,9 +9,9 @@ import ( "gorgonia.org/tensor" ) -func TestGemmInit(t *testing.T) { - gemm := Gemm{} - err := gemm.Init(GemmOnnxNodeProtoFixture()) +func TestGemm13Init(t *testing.T) { + gemm := Gemm13{} + err := gemm.Init(Gemm13OnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, float32(10.0), gemm.alpha) @@ -20,52 +20,52 @@ func TestGemmInit(t *testing.T) { assert.Equal(t, true, gemm.transB) } -func TestGemmInitFail(t *testing.T) { - gemm := &Gemm{} +func TestGemm13InitFail(t *testing.T) { + gemm := &Gemm13{} err := gemm.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}}) expected := ops.ErrInvalidAttribute("unknownAttribute", gemm) assert.Equal(t, expected, err) } -func TestGemm(t *testing.T) { +func TestGemm13(t *testing.T) { tests := []struct { - gemm *Gemm + gemm *Gemm13 shapes [][]int expected []float32 }{ { - &Gemm{1, 1, false, false}, + &Gemm13{1, 1, false, false}, [][]int{{3, 2}, {2, 5}, {5}}, []float32{5, 7, 9, 11, 13, 15, 21, 27, 33, 39, 25, 35, 45, 55, 65}, }, { - &Gemm{1, 1, true, false}, + &Gemm13{1, 1, true, false}, [][]int{{2, 3}, {2, 5}, {5}}, []float32{15, 19, 23, 27, 31, 20, 26, 32, 38, 44, 25, 33, 41, 49, 57}, }, { - &Gemm{1, 1, true, true}, + &Gemm13{1, 1, true, true}, [][]int{{2, 3}, {5, 2}, {5}}, []float32{3, 10, 17, 24, 31, 4, 15, 26, 37, 48, 5, 20, 35, 50, 65}, }, { - &Gemm{1, 1, false, true}, + &Gemm13{1, 1, false, true}, [][]int{{3, 2}, {5, 2}, {5}}, []float32{1, 4, 7, 10, 13, 3, 14, 25, 36, 47, 5, 24, 43, 62, 81}, }, { - &Gemm{1, 1, false, false}, + &Gemm13{1, 1, false, false}, [][]int{{1, 2}, {2, 5}, {5}}, []float32{5, 7, 9, 11, 13}, }, { - &Gemm{1, 1, false, false}, + &Gemm13{1, 1, false, false}, [][]int{{1, 2}, {2, 5}}, []float32{5, 6, 7, 8, 9}, }, { - &Gemm{1, 1, false, false}, + &Gemm13{1, 1, false, false}, [][]int{{20, 4}, {4, 6}, {6}}, []float32{ 84, 91, 98, 105, 112, 119, 228, 251, 274, @@ -104,7 +104,7 @@ func TestGemm(t *testing.T) { } } -func TestInputValidationGemm(t *testing.T) { +func TestInputValidationGemm13(t *testing.T) { tests := []struct { inputs []tensor.Tensor expected []tensor.Tensor @@ -134,7 +134,7 @@ func TestInputValidationGemm(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - ops.ErrInvalidOptionalInputCount(1, &Gemm{}), + ops.ErrInvalidOptionalInputCount(1, &Gemm13{}), }, { []tensor.Tensor{ @@ -144,7 +144,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, - ops.ErrInvalidOptionalInputCount(4, &Gemm{}), + ops.ErrInvalidOptionalInputCount(4, &Gemm13{}), }, { []tensor.Tensor{ @@ -152,12 +152,12 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - ops.ErrInvalidInputType(0, "int", &Gemm{}), + ops.ErrInvalidInputType(0, "int", &Gemm13{}), }, } for _, test := range tests { - gemm := &Gemm{} + gemm := &Gemm13{} validated, err := gemm.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -172,7 +172,7 @@ func TestInputValidationGemm(t *testing.T) { } } -func GemmOnnxNodeProtoFixture() *onnx.NodeProto { +func Gemm13OnnxNodeProtoFixture() *onnx.NodeProto { return &onnx.NodeProto{ Attribute: []*onnx.AttributeProto{ {Name: "alpha", F: 10.0}, diff --git a/ops/gemm/gemm_7.go b/ops/gemm/gemm_7.go new file mode 100644 index 0000000..b7192f8 --- /dev/null +++ b/ops/gemm/gemm_7.go @@ -0,0 +1,130 @@ +package gemm + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinGemm7Inputs = 3 + MaxGemm7Inputs = 3 +) + +// Gemm7 represents the ONNX gemm operator. +type Gemm7 struct { + alpha float32 + beta float32 + transA bool + transB bool +} + +// newGemm7 creates a new gemm operator and initializes it with the default values. +func newGemm7() ops.Operator { + return &Gemm7{ + alpha: 1.0, + beta: 1.0, + transA: false, + transB: false, + } +} + +// Init initializes the Gemm7 operator based on the ModelProto attributes. +func (g *Gemm7) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "alpha": + g.alpha = attr.GetF() + case "beta": + g.beta = attr.GetF() + case "transA": + g.transA = ops.Int64ToBool(attr.GetI()) + case "transB": + g.transB = ops.Int64ToBool(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } + + return nil +} + +// Apply applies the gemm operator on the given graph. +func (g *Gemm7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + a := inputs[0] + b := inputs[1] + c := inputs[2] + + if g.transA { + a, err = tensor.Transpose(a) + if err != nil { + return nil, err + } + } + + if g.transB { + b, err = tensor.Transpose(b) + if err != nil { + return nil, err + } + } + + x, err := tensor.MatMul(a, b) + if err != nil { + return nil, err + } + + x, err = tensor.Mul(x, g.alpha) + if err != nil { + return nil, err + } + + y, err := tensor.Mul(c, g.beta) + if err != nil { + return nil, err + } + + x, y, err = ops.UnidirectionalBroadcast(x, y) + if err != nil { + return nil, err + } + + output, err := tensor.Add(x, y) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Gemm7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Gemm7) GetMinInputs() int { + return MinGemm7Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Gemm7) GetMaxInputs() int { + return MaxGemm7Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Gemm7) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Gemm7) String() string { + return "gemm7 operator" +} diff --git a/ops/gemm/gemm_9.go b/ops/gemm/gemm_9.go new file mode 100644 index 0000000..b06fb1c --- /dev/null +++ b/ops/gemm/gemm_9.go @@ -0,0 +1,130 @@ +package gemm + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinGemm9Inputs = 3 + MaxGemm9Inputs = 3 +) + +// Gemm9 represents the ONNX gemm operator. +type Gemm9 struct { + alpha float32 + beta float32 + transA bool + transB bool +} + +// newGemm9 creates a new gemm operator and initializes it with the default values. +func newGemm9() ops.Operator { + return &Gemm9{ + alpha: 1.0, + beta: 1.0, + transA: false, + transB: false, + } +} + +// Init initializes the Gemm9 operator based on the ModelProto attributes. +func (g *Gemm9) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "alpha": + g.alpha = attr.GetF() + case "beta": + g.beta = attr.GetF() + case "transA": + g.transA = ops.Int64ToBool(attr.GetI()) + case "transB": + g.transB = ops.Int64ToBool(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } + + return nil +} + +// Apply applies the gemm operator on the given graph. +func (g *Gemm9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + a := inputs[0] + b := inputs[1] + c := inputs[2] + + if g.transA { + a, err = tensor.Transpose(a) + if err != nil { + return nil, err + } + } + + if g.transB { + b, err = tensor.Transpose(b) + if err != nil { + return nil, err + } + } + + x, err := tensor.MatMul(a, b) + if err != nil { + return nil, err + } + + x, err = tensor.Mul(x, g.alpha) + if err != nil { + return nil, err + } + + y, err := tensor.Mul(c, g.beta) + if err != nil { + return nil, err + } + + x, y, err = ops.UnidirectionalBroadcast(x, y) + if err != nil { + return nil, err + } + + output, err := tensor.Add(x, y) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Gemm9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Gemm9) GetMinInputs() int { + return MinGemm9Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Gemm9) GetMaxInputs() int { + return MaxGemm9Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Gemm9) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Gemm9) String() string { + return "gemm9 operator" +} diff --git a/ops/gemm/versions.go b/ops/gemm/versions.go new file mode 100644 index 0000000..464b392 --- /dev/null +++ b/ops/gemm/versions.go @@ -0,0 +1,10 @@ +package gemm + +import "github.com/advancedclimatesystems/gonnx/ops" + +var GemmVersions = ops.OperatorVersions{ + 7: newGemm7, + 9: newGemm9, + 11: newGemm11, + 13: newGemm13, +} diff --git a/ops/greater/greater_13.go b/ops/greater/greater_13.go new file mode 100644 index 0000000..c4371aa --- /dev/null +++ b/ops/greater/greater_13.go @@ -0,0 +1,61 @@ +package greater + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinGreater13Inputs = 2 + MaxGreater13Inputs = 2 +) + +// Greater13 represents the ONNX greater operator. +type Greater13 struct{} + +// newGreater13 creates a new greater operator. +func newGreater13() ops.Operator { + return &Greater13{} +} + +// Init initializes the greater operator. +func (g *Greater13) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greater operator. +func (g *Greater13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gt, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Greater13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Greater13) GetMinInputs() int { + return MinGreater13Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Greater13) GetMaxInputs() int { + return MaxGreater13Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Greater13) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Greater13) String() string { + return "greater13 operator" +} diff --git a/ops/opset13/greater_test.go b/ops/greater/greater_13_test.go similarity index 87% rename from ops/opset13/greater_test.go rename to ops/greater/greater_13_test.go index 18bc294..7606829 100644 --- a/ops/opset13/greater_test.go +++ b/ops/greater/greater_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package greater import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestGreaterInit(t *testing.T) { - g := &Greater{} +func TestGreater13Init(t *testing.T) { + g := &Greater13{} // since 'greater' does not have any attributes we pass in nil. This should not // fail initializing the greater. @@ -17,27 +17,27 @@ func TestGreaterInit(t *testing.T) { assert.Nil(t, err) } -func TestGreater(t *testing.T) { +func TestGreater13(t *testing.T) { tests := []struct { - greater *Greater + greater *Greater13 backings [][]float32 shapes [][]int expected []bool }{ { - &Greater{}, + &Greater13{}, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, false, true, true}, }, { - &Greater{}, + &Greater13{}, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, false, true, true, true}, }, { - &Greater{}, + &Greater13{}, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{false, false, false, false}, @@ -58,7 +58,7 @@ func TestGreater(t *testing.T) { } } -func TestInputValidationGreater(t *testing.T) { +func TestInputValidationGreater13(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -109,19 +109,19 @@ func TestInputValidationGreater(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Greater{}), + ops.ErrInvalidInputCount(1, &Greater13{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Greater{}), + ops.ErrInvalidInputType(0, "int", &Greater13{}), }, } for _, test := range tests { - greater := &Greater{} + greater := &Greater13{} validated, err := greater.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/greater/greater_7.go b/ops/greater/greater_7.go new file mode 100644 index 0000000..81723f8 --- /dev/null +++ b/ops/greater/greater_7.go @@ -0,0 +1,61 @@ +package greater + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinGreater7Inputs = 2 + MaxGreater7Inputs = 2 +) + +// Greater7 represents the ONNX greater operator. +type Greater7 struct{} + +// newGreater7 creates a new greater operator. +func newGreater7() ops.Operator { + return &Greater7{} +} + +// Init initializes the greater operator. +func (g *Greater7) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greater operator. +func (g *Greater7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gt, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Greater7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Greater7) GetMinInputs() int { + return MinGreater7Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Greater7) GetMaxInputs() int { + return MaxGreater7Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Greater7) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Greater7) String() string { + return "greater7 operator" +} diff --git a/ops/opset13/greater.go b/ops/greater/greater_9.go similarity index 57% rename from ops/opset13/greater.go rename to ops/greater/greater_9.go index 37e5af4..7b19159 100644 --- a/ops/opset13/greater.go +++ b/ops/greater/greater_9.go @@ -1,4 +1,4 @@ -package opset13 +package greater import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,25 +7,25 @@ import ( ) var ( - MinGreaterInputs = 2 - MaxGreaterInputs = 2 + MinGreater9Inputs = 2 + MaxGreater9Inputs = 2 ) -// Greater represents the ONNX greater operator. -type Greater struct{} +// Greater9 represents the ONNX greater operator. +type Greater9 struct{} -// newGreater creates a new greater operator. -func newGreater() ops.Operator { - return &Greater{} +// newGreater9 creates a new greater operator. +func newGreater9() ops.Operator { + return &Greater9{} } // Init initializes the greater operator. -func (g *Greater) Init(*onnx.NodeProto) error { +func (g *Greater9) Init(*onnx.NodeProto) error { return nil } // Apply applies the greater operator. -func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Greater9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ApplyBinaryOperation( inputs[0], inputs[1], @@ -35,27 +35,27 @@ func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Greater) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *Greater9) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(g, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Greater) GetMinInputs() int { - return MinGreaterInputs +func (g *Greater9) GetMinInputs() int { + return MinGreater9Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Greater) GetMaxInputs() int { - return MaxGreaterInputs +func (g *Greater9) GetMaxInputs() int { + return MaxGreater9Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (g *Greater) GetInputTypeConstraints() [][]tensor.Dtype { +func (g *Greater9) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} } // String implements the stringer interface, and can be used to format errors or messages. -func (g *Greater) String() string { - return "greater operator" +func (g *Greater9) String() string { + return "greater9 operator" } diff --git a/ops/greater/versions.go b/ops/greater/versions.go new file mode 100644 index 0000000..a443703 --- /dev/null +++ b/ops/greater/versions.go @@ -0,0 +1,9 @@ +package greater + +import "github.com/advancedclimatesystems/gonnx/ops" + +var GreaterVersions = ops.OperatorVersions{ + 7: newGreater7, + 9: newGreater9, + 13: newGreater13, +} diff --git a/ops/opset13/greater_or_equal.go b/ops/greaterorequal/greater_or_equal_12.go similarity index 52% rename from ops/opset13/greater_or_equal.go rename to ops/greaterorequal/greater_or_equal_12.go index 25eb27b..547245a 100644 --- a/ops/opset13/greater_or_equal.go +++ b/ops/greaterorequal/greater_or_equal_12.go @@ -1,4 +1,4 @@ -package opset13 +package greaterorequal import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -7,25 +7,25 @@ import ( ) var ( - MinGreaterOrEqualInputs = 2 - MaxGreaterOrEqualInputs = 2 + MinGreaterOrEqual12Inputs = 2 + MaxGreaterOrEqual12Inputs = 2 ) -// GreaterOrEqual represents the ONNX greaterOrEqual operator. -type GreaterOrEqual struct{} +// GreaterOrEqual12 represents the ONNX greaterOrEqual operator. +type GreaterOrEqual12 struct{} -// newGreaterOrEqual creates a new greaterOrEqual operator. -func newGreaterOrEqual() ops.Operator { - return &GreaterOrEqual{} +// newGreaterOrEqual12 creates a new greaterOrEqual operator. +func newGreaterOrEqual12() ops.Operator { + return &GreaterOrEqual12{} } // Init initializes the greaterOrEqual operator. -func (g *GreaterOrEqual) Init(*onnx.NodeProto) error { +func (g *GreaterOrEqual12) Init(*onnx.NodeProto) error { return nil } // Apply applies the greaterOrEqual operator. -func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *GreaterOrEqual12) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ApplyBinaryOperation( inputs[0], inputs[1], @@ -35,27 +35,27 @@ func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *GreaterOrEqual) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (g *GreaterOrEqual12) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return ops.ValidateInputs(g, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *GreaterOrEqual) GetMinInputs() int { - return MinGreaterOrEqualInputs +func (g *GreaterOrEqual12) GetMinInputs() int { + return MinGreaterOrEqual12Inputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *GreaterOrEqual) GetMaxInputs() int { - return MaxGreaterOrEqualInputs +func (g *GreaterOrEqual12) GetMaxInputs() int { + return MaxGreaterOrEqual12Inputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (g *GreaterOrEqual) GetInputTypeConstraints() [][]tensor.Dtype { +func (g *GreaterOrEqual12) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} } // String implements the stringer interface, and can be used to format errors or messages. -func (g *GreaterOrEqual) String() string { - return "greaterOrEqual operator" +func (g *GreaterOrEqual12) String() string { + return "greaterOrEqual12 operator" } diff --git a/ops/opset13/greater_or_equal_test.go b/ops/greaterorequal/greater_or_equal_12_test.go similarity index 84% rename from ops/opset13/greater_or_equal_test.go rename to ops/greaterorequal/greater_or_equal_12_test.go index 37f5dec..4f1db87 100644 --- a/ops/opset13/greater_or_equal_test.go +++ b/ops/greaterorequal/greater_or_equal_12_test.go @@ -1,4 +1,4 @@ -package opset13 +package greaterorequal import ( "testing" @@ -8,8 +8,8 @@ import ( "gorgonia.org/tensor" ) -func TestGreaterOrEqualInit(t *testing.T) { - g := &GreaterOrEqual{} +func TestGreaterOrEqual12Init(t *testing.T) { + g := &GreaterOrEqual12{} // since 'greaterOrEqual' does not have any attributes we pass in nil. This should not // fail initializing the greaterOrEqual. @@ -17,27 +17,27 @@ func TestGreaterOrEqualInit(t *testing.T) { assert.Nil(t, err) } -func TestGreaterOrEqual(t *testing.T) { +func TestGreaterOrEqual12(t *testing.T) { tests := []struct { - greaterOrEqual *GreaterOrEqual + greaterOrEqual *GreaterOrEqual12 backings [][]float32 shapes [][]int expected []bool }{ { - &GreaterOrEqual{}, + &GreaterOrEqual12{}, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, true, true, true}, }, { - &GreaterOrEqual{}, + &GreaterOrEqual12{}, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, true, true, true, true}, }, { - &GreaterOrEqual{}, + &GreaterOrEqual12{}, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{true, true, false, false}, @@ -58,7 +58,7 @@ func TestGreaterOrEqual(t *testing.T) { } } -func TestInputValidationGreaterOrEqual(t *testing.T) { +func TestInputValidationGreaterOrEqual12(t *testing.T) { tests := []struct { inputs []tensor.Tensor err error @@ -109,19 +109,19 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &GreaterOrEqual{}), + ops.ErrInvalidInputCount(1, &GreaterOrEqual12{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &GreaterOrEqual{}), + ops.ErrInvalidInputType(0, "int", &GreaterOrEqual12{}), }, } for _, test := range tests { - greaterOrEqual := &GreaterOrEqual{} + greaterOrEqual := &GreaterOrEqual12{} validated, err := greaterOrEqual.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/greaterorequal/versions.go b/ops/greaterorequal/versions.go new file mode 100644 index 0000000..e01e690 --- /dev/null +++ b/ops/greaterorequal/versions.go @@ -0,0 +1,7 @@ +package greaterorequal + +import "github.com/advancedclimatesystems/gonnx/ops" + +var GreaterOrEqualVersions = ops.OperatorVersions{ + 12: newGreaterOrEqual12, +} diff --git a/ops/gru/gru_7.go b/ops/gru/gru_7.go new file mode 100644 index 0000000..f5196e4 --- /dev/null +++ b/ops/gru/gru_7.go @@ -0,0 +1,362 @@ +package gru + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/gemm" + "gorgonia.org/tensor" +) + +const ( + MinGRU7Inputs = 3 + MaxGRU7Inputs = 6 +) + +// GRU7 represents the ONNX gru operator. It only supports a simple forward gru +// operation with default activations. +type GRU7 struct { + activationAlpha []float32 + activationBeta []float32 + activations []string + direction ops.SequenceProcessDirection + hiddenSize int + linearBeforeReset bool +} + +// newGRU7 creates a new gru operator. +func newGRU7() ops.Operator { + return &GRU7{ + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + linearBeforeReset: false, + } +} + +// Init initializes the gru operator. Currently, our GRU7 operator does not support all +// attributes as specified by the ONNX operator. The basic functionality is working and +// the other attributes can be added later on. +func (g *GRU7) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + for _, attr := range attributes { + switch attr.GetName() { + case ops.ActivationAlphaAttr: + g.activationAlpha = attr.GetFloats() + case ops.ActivationBetaAttr: + g.activationBeta = attr.GetFloats() + case ops.ActivationsAttr: + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + g.activations = activations + case ops.ClipAttr: + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + case ops.DirectionAttr: + g.direction = ops.SequenceProcessDirection(attr.GetS()) + if g.direction != ops.Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + } + case ops.HiddenSizeAttr: + g.hiddenSize = int(attr.GetI()) + case "linear_before_reset": + g.linearBeforeReset = ops.Int64ToBool(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } + + return nil +} + +// Apply applies the gru operator. +func (g *GRU7) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + if inputs[4] != nil { + return nil, ops.ErrUnsupportedInput("sequence lens", g) + } + + X := inputs[0] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] + + Wz, Wr, Wh, err := g.getWeights(inputs[1]) + if err != nil { + return nil, err + } + + Rz, Rr, Rh, err := g.getWeights(inputs[2]) + if err != nil { + return nil, err + } + + B := inputs[3] + if B == nil { + // 6 is the number of bias matrices required by ONNX definition. + nBiasMatrices := 6 + B = ops.ZeroTensor(1, nBiasMatrices*g.hiddenSize) + } + + Wbz, Wbr, Wbh, Rbz, Rbr, Rbh, err := g.getBiases(B) + if err != nil { + return nil, err + } + + prevH := inputs[5] + if prevH == nil { + prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize) + } + + // Extract the shape of the hidden dimensions without the bidirectional dimension, as + // we do not support bidirectional GRU7 yet. + shapeWithoutBidir := prevH.Shape().Clone()[1:] + + err = prevH.Reshape(shapeWithoutBidir...) + if err != nil { + return nil, err + } + + fActivation, err := ops.GetActivation(g.activations[0]) + if err != nil { + return nil, err + } + + gActivation, err := ops.GetActivation(g.activations[1]) + if gActivation == nil { + return nil, err + } + + outputs := []tensor.Tensor{} + + for i := 0; i < seqLength; i++ { + Xt, err := g.extractXt(X, i) + if err != nil { + return nil, err + } + + zt, err := g.gateCalculation(Xt, prevH, Wz, Rz, Wbz, Rbz, fActivation) + if err != nil { + return nil, err + } + + rt, err := g.gateCalculation(Xt, prevH, Wr, Rr, Wbr, Rbr, fActivation) + if err != nil { + return nil, err + } + + ht, err := g.htCalculation(Xt, prevH, rt, Wh, Rh, Wbh, Rbh, gActivation) + if err != nil { + return nil, err + } + + prevH, err = g.hiddenCalculation(zt, ht, prevH) + if err != nil { + return nil, err + } + + outputs = append(outputs, prevH) + } + + var Y tensor.Tensor + if len(outputs) > 1 { + Y, err = tensor.Concat(0, outputs[0], outputs[1:]...) + if err != nil { + return nil, err + } + } else { + Y = outputs[0] + } + + // Reshape the output so it adds the num_directions as specified by onnx. + err = Y.Reshape([]int{seqLength, 1, batchSize, g.hiddenSize}...) + if err != nil { + return nil, err + } + + Yh, ok := prevH.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone()) + } + + // Reshape the output so it adds the num_directions as specified by onnx. + err = Yh.Reshape([]int{1, batchSize, g.hiddenSize}...) + if err != nil { + return nil, err + } + + return []tensor.Tensor{Y, Yh}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *GRU7) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *GRU7) GetMinInputs() int { + return MinGRU7Inputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *GRU7) GetMaxInputs() int { + return MaxGRU7Inputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *GRU7) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *GRU7) String() string { + return "gru7 operator" +} + +// extractXt extracts the value of x for timestep t. +func (g *GRU7) extractXt(X tensor.Tensor, t int) (tensor.Tensor, error) { + return X.Slice(ops.NewSlicer(t, t+1), nil, nil) +} + +func (g *GRU7) gateCalculation( + Xt, H, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, +) (tensor.Tensor, error) { + gemm := gemm.GemmVersions[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } + + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) + if err != nil { + return nil, err + } + + hiddenCalc, err := gemm.Apply([]tensor.Tensor{H, R, Rb}) + if err != nil { + return nil, err + } + + gate, err := tensor.Add(inputCalc[0], hiddenCalc[0]) + if err != nil { + return nil, err + } + + return activation(gate) +} + +func (g *GRU7) htCalculation( + Xt, prevH, rt, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, +) (tensor.Tensor, error) { + if !g.linearBeforeReset { + temp1, err := tensor.Mul(rt, prevH) + if err != nil { + return nil, err + } + + return g.gateCalculation(Xt, temp1, W, R, Wb, Rb, activation) + } + + gemm := gemm.GemmVersions[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } + + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) + if err != nil { + return nil, err + } + + hiddenCalc, err := gemm.Apply([]tensor.Tensor{prevH, R, Rb}) + if err != nil { + return nil, err + } + + temp1, err := tensor.Mul(hiddenCalc[0], rt) + if err != nil { + return nil, err + } + + temp2, err := tensor.Add(temp1, inputCalc[0]) + if err != nil { + return nil, err + } + + return activation(temp2) +} + +func (g *GRU7) hiddenCalculation(zt, ht, prevH tensor.Tensor) (tensor.Tensor, error) { + temp1, err := tensor.Sub(ops.OnesTensor(zt), zt) + if err != nil { + return nil, err + } + + temp2, err := tensor.Mul(temp1, ht) + if err != nil { + return nil, err + } + + temp3, err := tensor.Mul(zt, prevH) + if err != nil { + return nil, err + } + + return tensor.Add(temp2, temp3) +} + +// getWeights splits tensor W into 3 weight matrices. +// The W tensor, by GONNX definition, has 3 dimensions with 3 weight +// tensors in it (6 if bidirectional, but that is not supported). +func (g *GRU7) getWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { + nWeightMatrices := 3 + nWeightDimensions := 3 + + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, g.hiddenSize) + if err != nil { + return nil, nil, nil, err + } + + return weights[0], weights[1], weights[2], nil +} + +// getBiases returns the biases from the Bias node as specified by the ONNX standard. +// The B tensor, by GONNX definition, has 2 dimensions with 6 bias +// tensors in it (12 if bidirectional, but that is not supported). +func (g *GRU7) getBiases(B tensor.Tensor) (Wbz, Wbr, Wbh, Rbz, Rbr, Rbh tensor.Tensor, err error) { + nBiasMatrices := 6 + nBiasDimensions := 2 + + biases, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, g.hiddenSize) + if err != nil { + return nil, nil, nil, nil, nil, nil, err + } + + return biases[0], biases[1], biases[2], biases[3], biases[4], biases[5], nil +} diff --git a/ops/gru/gru_7_test.go b/ops/gru/gru_7_test.go new file mode 100644 index 0000000..4b69090 --- /dev/null +++ b/ops/gru/gru_7_test.go @@ -0,0 +1,298 @@ +package gru + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestGruInit(t *testing.T) { + gru := &GRU7{} + err := gru.Init(GRU7OnnxNodeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.0}, gru.activationAlpha) + assert.Equal(t, []float32{2.0}, gru.activationBeta) + assert.Equal(t, []string{"sigmoid", "tanh"}, gru.activations) + assert.Equal(t, gru.direction, ops.Forward) + assert.Equal(t, 5, gru.hiddenSize) + assert.Equal(t, true, gru.linearBeforeReset) +} + +func TestGruInitUnkownAttr(t *testing.T) { + gru := GRU7{} + tests := []struct { + attr []*onnx.AttributeProto + err error + }{ + { + []*onnx.AttributeProto{{Name: "clip"}}, + ops.ErrUnsupportedAttribute("clip", &gru), + }, + { + []*onnx.AttributeProto{{Name: "unknown"}}, + ops.ErrInvalidAttribute("unknown", &gru), + }, + } + + for _, test := range tests { + err := gru.Init(&onnx.NodeProto{Attribute: test.attr}) + assert.Equal(t, test.err, err) + } +} + +func TestGru(t *testing.T) { + tests := []struct { + gru *GRU7 + inputs ops.InputFixture + expected []float32 + err error + }{ + { + &GRU7{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: true, + }, + gruInput0, + []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, + nil, + }, + { + &GRU7{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, + gruInput0, + []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, + nil, + }, + { + &GRU7{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, + gruInput1, + []float32{0.44905475, 0.4406946, 0.43368173, 0.42782417}, + nil, + }, + { + &GRU7{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, + gruInputNoBNoH, + []float32{0.24553154, 0.24553154, 0.24553154, 0.24553154}, + nil, + }, + } + + for _, test := range tests { + inputs := test.inputs() + res, err := test.gru.Apply(inputs) + assert.Equal(t, test.err, err) + + if err == nil { + assert.Equal(t, test.expected, res[1].Data()) + } + } +} + +func TestInputValidationGRU7(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + expected []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + nil, + nil, + nil, + }, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "int", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(2, "int", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(3, "int", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(4, "float32", &GRU7{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(5, "int", &GRU7{}), + }, + } + + for _, test := range tests { + gru := &GRU7{} + validated, err := gru.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + if test.expected != nil { + assert.Equal(t, test.expected, validated) + } else { + assert.Equal(t, test.inputs, validated) + } + } + } +} + +func gruInput0() []tensor.Tensor { + shapes := [][]int{{2, 1, 3}, {1, 12, 3}, {1, 12, 4}, {1, 24}, {1, 1, 4}} + inputs := []tensor.Tensor{ + ops.Float32TensorFixture(shapes[0]...), + ops.Float32TensorFixture(shapes[1]...), + ops.Float32TensorFixture(shapes[2]...), + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(shapes[3]...)), shapes[3]...), + nil, + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(shapes[4]...)), shapes[4]...), + } + + return inputs +} + +func gruInput1() []tensor.Tensor { + shps := [][]int{{10, 1, 3}, {1, 12, 3}, {1, 12, 4}, {1, 24}, {1, 1, 4}} + inputs := []tensor.Tensor{ + ops.Float32TensorFixture(shps[0]...), + ops.TensorWithBackingFixture(ops.Full(ops.NElements(shps[1]...), 0.2), shps[1]...), + ops.TensorWithBackingFixture(ops.Full(ops.NElements(shps[2]...), 0.5), shps[2]...), + ops.TensorWithBackingFixture(ops.Arange(ops.NElements(shps[3]...), 0.10), shps[3]...), + nil, + ops.TensorWithBackingFixture(ops.Full(ops.NElements(shps[4]...), 0.4), shps[4]...), + } + + return inputs +} + +func gruInputNoBNoH() []tensor.Tensor { + shps := [][]int{{10, 1, 3}, {1, 12, 3}, {1, 12, 4}, {1, 24}, {1, 1, 4}} + inputs := []tensor.Tensor{ + ops.Float32TensorFixture(shps[0]...), + ops.TensorWithBackingFixture(ops.Full(ops.NElements(shps[1]...), 0.2), shps[1]...), + ops.TensorWithBackingFixture(ops.Full(ops.NElements(shps[2]...), 0.5), shps[2]...), + nil, + nil, + nil, + } + + return inputs +} + +func GRU7OnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "linear_before_reset", I: 1}, + }, + } +} diff --git a/ops/gru/versions.go b/ops/gru/versions.go new file mode 100644 index 0000000..6c325b1 --- /dev/null +++ b/ops/gru/versions.go @@ -0,0 +1,7 @@ +package gru + +import "github.com/advancedclimatesystems/gonnx/ops" + +var GRUVersions = ops.OperatorVersions{ + 7: newGRU7, +} diff --git a/ops/operator.go b/ops/operator.go index 7f26e4d..3ee5b9d 100644 --- a/ops/operator.go +++ b/ops/operator.go @@ -5,6 +5,8 @@ import ( "gorgonia.org/tensor" ) +type OperatorVersions map[int64]func() Operator + // Operator is the base interface for all operators. type Operator interface { // String should return a simple string describing the operator diff --git a/opset.go b/opset.go index 00bcf5f..5290968 100644 --- a/opset.go +++ b/opset.go @@ -23,6 +23,11 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/equal" "github.com/advancedclimatesystems/gonnx/ops/expand" "github.com/advancedclimatesystems/gonnx/ops/flatten" + "github.com/advancedclimatesystems/gonnx/ops/gather" + "github.com/advancedclimatesystems/gonnx/ops/gemm" + "github.com/advancedclimatesystems/gonnx/ops/greater" + "github.com/advancedclimatesystems/gonnx/ops/greaterorequal" + "github.com/advancedclimatesystems/gonnx/ops/gru" ) const ( @@ -33,97 +38,33 @@ const ( // OpGetter is a function that gets an operator based on a string. type OpGetter func(string) (ops.Operator, error) -type OperatorVersions map[int64]func() ops.Operator - -var operators = map[string]OperatorVersions{ - "Abs": { - 6: abs.NewAbs6, // Same, but bfloat16 type is added - 13: abs.NewAbs13, - }, - "Acos": { - 7: acos.NewAcos7, - }, - "Acosh": { - 9: acosh.NewAcosh9, - }, - "Add": { - 7: add.NewAdd7, // Same, but bfloat16 type is added - 13: add.NewAdd13, - }, - "And": { - 7: and.NewAnd7, - }, - "ArgMax": { - 11: argmax.NewArgMax11, // Same, but one attribute is added (which we don't support it anyway) - 12: argmax.NewArgMax12, // Same, but bfloat16 type differs - 13: argmax.NewArgMax13, - }, - "Asin": { - 7: asin.NewAsin7, - }, - "Asinh": { - 9: asinh.NewAsinh9, - }, - "Atan": { - 7: atan.NewAtan7, - }, - "Atanh": { - 9: atanh.NewAtanh9, - }, - "Cast": { - 6: cast.NewCast6, // Same, but string type is added - 9: cast.NewCast9, // Same, but bfloat16 type differs - 13: cast.NewCast13, - }, - "Concat": { - 4: concat.NewConcat4, - 11: concat.NewConcat11, // Same, but bfloat16 type differs - 13: concat.NewConcat13, - }, - "Constant": { - 1: constant.NewConstant1, - 9: constant.NewConstant9, - 11: constant.NewConstant11, - 12: constant.NewConstant12, // Same, but bfloat16 type differs - 13: constant.NewConstant13, - }, - "ConstantOfShape": { - 9: constantofshape.NewConstantOfShape9, - }, - "Conv": { - 1: conv.NewConv1, // Same, but only float16 type differs - 11: conv.NewConv11, - }, - "Cos": { - 7: cos.NewCos7, - }, - "Cosh": { - 9: cosh.NewCosh9, - }, - "Div": { - 7: div.NewDiv7, // Same, but float16 type differs - 13: div.NewDiv13, - }, - "Equal": { - 7: equal.NewEqual7, - 11: equal.NewEqual11, // Same, but float16 type differs - 13: equal.NewEqual13, - }, - "Expand": { - 8: expand.NewExpand8, // Same, but float16 type differs - 13: expand.NewExpand13, - }, - "Flatten": { - 1: flatten.NewFlatten1, // Same, but only float types - 9: flatten.NewFlatten9, // Same, but negative axis added - 11: flatten.NewFlatten11, // Same, but float16 type differs - 13: flatten.NewFlatten13, - }, - "Gather": {}, - "Gemm": {}, - "Greater": {}, - "GreaterOrEqual": {}, - "GRU": {}, +var operators = map[string]ops.OperatorVersions{ + "Abs": abs.AbsVersions, + "Acos": acos.AcosVersions, + "Acosh": acosh.AcoshVersions, + "Add": add.AddVersions, + "And": and.AndVersions, + "ArgMax": argmax.ArgMaxVersions, + "Asin": asin.AsinVersions, + "Asinh": asinh.AsinhVersions, + "Atan": atan.AtanVersions, + "Atanh": atanh.AtanhVersions, + "Cast": cast.CastVersions, + "Concat": concat.ConcatVersions, + "Constant": constant.ConstantVersions, + "ConstantOfShape": constantofshape.ConstantOfShapeVersions, + "Conv": conv.ConvVersions, + "Cos": cos.CosVersions, + "Cosh": cosh.CoshVersions, + "Div": div.DivVersions, + "Equal": equal.EqualVersions, + "Expand": expand.ExpandVersions, + "Flatten": flatten.FlattenVersions, + "Gather": gather.GatherVersions, + "Gemm": gemm.GemmVersions, + "Greater": greater.GreaterVersions, + "GreaterOrEqual": greaterorequal.GreaterOrEqualVersions, + "GRU": gru.GRUVersions, "Less": {}, "LessOrEqual": {}, "LinearRegressor": {},