diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index d53567a..0b6497d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Install linter run: make install_lint @@ -34,7 +34,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Install dependencies run: make install @@ -56,7 +56,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Build amd64 run: make build_amd64 @@ -69,7 +69,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.21 - name: Build arm64 run: make build_arm64 diff --git a/onnx/graph_proto.go b/onnx/graph_proto.go index 046b7e0..88e5d54 100644 --- a/onnx/graph_proto.go +++ b/onnx/graph_proto.go @@ -556,7 +556,7 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) { // ReadUint64ArrayFromBytes reads data and parses it to an array of uint64. func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) { buffer := bytes.NewReader(data) - element := make([]byte, int32Size) + element := make([]byte, int64Size) var ( err error diff --git a/ops/binary_op.go b/ops/binary_op.go index 3df36e8..ba2aa19 100644 --- a/ops/binary_op.go +++ b/ops/binary_op.go @@ -1,6 +1,8 @@ package ops import ( + "slices" + "gorgonia.org/tensor" ) @@ -48,6 +50,48 @@ func Mul(A, B tensor.Tensor) (tensor.Tensor, error) { return tensor.Mul(A, B) } +// Pow raises the first tensor to the power of the second tensor. +// Because the gorgonia.Tensor 'Pow' operation only supports float32 and float64, +// we need to convert the tensors to float64 if they are of a different type. +// After the operation is done, we convert the result back to the original type. +func Pow(A, B tensor.Tensor) (tensor.Tensor, error) { + needsConversion := false + if slices.Contains(IntTypes, A.Dtype()) { + needsConversion = true + } + + if !needsConversion { + return tensor.Pow(A, B) + } + + oldType, err := DTypeToONNXType(A.Dtype()) + if err != nil { + return nil, err + } + + newType, err := DTypeToONNXType(tensor.Float64) + if err != nil { + return nil, err + } + + A, err = ConvertTensorDtype(A, newType) + if err != nil { + return nil, err + } + + B, err = ConvertTensorDtype(B, newType) + if err != nil { + return nil, err + } + + out, err := tensor.Pow(A, B) + if err != nil { + return nil, err + } + + return ConvertTensorDtype(out, oldType) +} + // Sub subtracts 1 tensor from the other. func Sub(A, B tensor.Tensor) (tensor.Tensor, error) { return tensor.Sub(A, B) diff --git a/ops/convert.go b/ops/convert.go index 0637f49..dd313cc 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -51,6 +51,33 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { return tensor.New(tensor.WithShape(t.Shape()...), tensor.WithBacking(newBacking)), nil } +func DTypeToONNXType(t tensor.Dtype) (int32, error) { + switch t { + case tensor.Float32: + return int32(onnx.TensorProto_FLOAT), nil + case tensor.Float64: + return int32(onnx.TensorProto_DOUBLE), nil + case tensor.Int8: + return int32(onnx.TensorProto_INT8), nil + case tensor.Int16: + return int32(onnx.TensorProto_INT16), nil + case tensor.Int32: + return int32(onnx.TensorProto_INT32), nil + case tensor.Int64: + return int32(onnx.TensorProto_INT64), nil + case tensor.Uint8: + return int32(onnx.TensorProto_UINT8), nil + case tensor.Uint16: + return int32(onnx.TensorProto_UINT16), nil + case tensor.Uint32: + return int32(onnx.TensorProto_UINT32), nil + case tensor.Uint64: + return int32(onnx.TensorProto_UINT64), nil + default: + return 0, ErrUnknownTensorONNXDtype(t) + } +} + func convertBacking[B Number](backing []B, dataType int32) (any, error) { switch onnx.TensorProto_DataType(dataType) { case onnx.TensorProto_FLOAT: diff --git a/ops/errors.go b/ops/errors.go index 0518d6f..bf6fe87 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -282,6 +282,10 @@ func ErrConversionNotSupported(dType int32) error { return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) } +func ErrUnknownTensorONNXDtype(dType tensor.Dtype) error { + return fmt.Errorf("%w: tensor with dtype %v does not have a corresponding onnx type", ErrCast, dType) +} + var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented") func ErrActivationNotImplemented(activation string) error { diff --git a/ops/pow/pow.go b/ops/pow/pow.go new file mode 100644 index 0000000..1e08a9e --- /dev/null +++ b/ops/pow/pow.go @@ -0,0 +1,63 @@ +package pow + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var pow7TypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + +var powTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Pow represents the ONNX pow operator. +type Pow struct { + ops.BaseOperator +} + +// newPow creates a new pow operator. +func newPow(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Pow{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "pow", + ), + } +} + +// Init initializes the pow operator. +func (a *Pow) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the pow operator. +func (a *Pow) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + powTensor := inputs[1] + if inputs[0].Dtype() != powTensor.Dtype() { + to, err := ops.DTypeToONNXType(inputs[0].Dtype()) + if err != nil { + return nil, err + } + + powTensor, err = ops.ConvertTensorDtype(powTensor, to) + if err != nil { + return nil, err + } + } + + return ops.ApplyBinaryOperation( + inputs[0], + powTensor, + ops.Pow, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/pow/pow_test.go b/ops/pow/pow_test.go new file mode 100644 index 0000000..b85a0ec --- /dev/null +++ b/ops/pow/pow_test.go @@ -0,0 +1,68 @@ +package pow + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestPowInit(t *testing.T) { + p := &Pow{} + err := p.Init(nil) + assert.Nil(t, err) +} + +func TestPow(t *testing.T) { + tests := []struct { + version int64 + backing0 any + backing1 any + shapes [][]int + expected any + }{ + { + 13, + []float32{0, 1, 2, 3}, + []float32{1, 1, 1, 1}, + [][]int{{2, 2}, {2, 2}}, + []float32{0, 1, 2, 3}, + }, + { + 13, + []float32{0, 1, 2, 3, 4, 5}, + []float32{2, 2, 2, 2, 2, 2}, + [][]int{{3, 2}, {3, 2}}, + []float32{0, 1, 4, 9, 16, 25}, + }, + { + 13, + []float32{0, 1}, + []float32{0, 1, 2, 3}, + [][]int{{2}, {2, 2}}, + []float32{1, 1, 0, 1}, + }, + { + 13, + []int32{1, 2, 3}, + []int32{4, 5, 6}, + [][]int{{3}, {3}}, + []int32{1, 32, 729}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing0, test.shapes[0]...), + ops.TensorWithBackingFixture(test.backing1, test.shapes[1]...), + } + + pow := powVersions[test.version]() + + res, err := pow.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expected, res[0].Data()) + } +} diff --git a/ops/pow/versions.go b/ops/pow/versions.go new file mode 100644 index 0000000..eb77068 --- /dev/null +++ b/ops/pow/versions.go @@ -0,0 +1,15 @@ +package pow + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var powVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newPow, 7, pow7TypeConstraints), + 12: ops.NewOperatorConstructor(newPow, 12, powTypeConstraints), + 13: ops.NewOperatorConstructor(newPow, 13, powTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return powVersions +} diff --git a/ops/types.go b/ops/types.go index fdc0f81..385ba0c 100644 --- a/ops/types.go +++ b/ops/types.go @@ -21,6 +21,12 @@ var AllTypes = []tensor.Dtype{ tensor.Bool, } +// IntTypes is a list with all integer types. +var IntTypes = []tensor.Dtype{ + tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, + tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, +} + // NumericTypes is a list with all numeric types. var NumericTypes = []tensor.Dtype{ tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, diff --git a/ops_test.go b/ops_test.go index 8965ff2..fc18995 100644 --- a/ops_test.go +++ b/ops_test.go @@ -187,7 +187,7 @@ func TestOps(t *testing.T) { if expectedTensor.Dtype() == tensor.Bool { assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data()) } else { - assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001) + assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.001) } } }) @@ -472,6 +472,18 @@ var expectedTests = []string{ "test_or_bcast4v2d", "test_or_bcast4v3d", "test_or_bcast4v4d", + "test_pow", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + "test_pow_types_float32_int32", + "test_pow_types_float32_int64", + "test_pow_types_float32_uint32", + "test_pow_types_float32_uint64", + "test_pow_types_int32_float32", + "test_pow_types_int32_int32", + "test_pow_types_int64_float32", + "test_pow_types_int64_int64", "test_prelu_broadcast", "test_prelu_example", "test_relu", diff --git a/opset.go b/opset.go index 79437ad..ec6768a 100644 --- a/opset.go +++ b/opset.go @@ -38,6 +38,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/mul" "github.com/advancedclimatesystems/gonnx/ops/not" "github.com/advancedclimatesystems/gonnx/ops/or" + "github.com/advancedclimatesystems/gonnx/ops/pow" "github.com/advancedclimatesystems/gonnx/ops/prelu" "github.com/advancedclimatesystems/gonnx/ops/reducemax" "github.com/advancedclimatesystems/gonnx/ops/reducemin" @@ -106,6 +107,7 @@ var operators = map[string]ops.OperatorVersions{ "Mul": mul.GetMulVersions(), "Not": not.GetNotVersions(), "Or": or.GetOrVersions(), + "Pow": pow.GetVersions(), "PRelu": prelu.GetPReluVersions(), "ReduceMax": reducemax.GetReduceMaxVersions(), "ReduceMin": reducemin.GetReduceMinVersions(),