Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Pow operator #222

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved

- name: Install linter
run: make install_lint
Expand All @@ -34,7 +34,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Install dependencies
run: make install
Expand All @@ -56,7 +56,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Build amd64
run: make build_amd64
Expand All @@ -69,7 +69,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Build arm64
run: make build_arm64
2 changes: 1 addition & 1 deletion onnx/graph_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) {
// ReadUint64ArrayFromBytes reads data and parses it to an array of uint64.
func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) {
buffer := bytes.NewReader(data)
element := make([]byte, int32Size)
element := make([]byte, int64Size)

var (
err error
Expand Down
44 changes: 44 additions & 0 deletions ops/binary_op.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ops

import (
"slices"

"gorgonia.org/tensor"
)

Expand Down Expand Up @@ -48,6 +50,48 @@ func Mul(A, B tensor.Tensor) (tensor.Tensor, error) {
return tensor.Mul(A, B)
}

// Pow raises the first tensor to the power of the second tensor.
// Because the gorgonia.Tensor 'Pow' operation only supports float32 and float64,
// we need to convert the tensors to float64 if they are of a different type.
// After the operation is done, we convert the result back to the original type.
func Pow(A, B tensor.Tensor) (tensor.Tensor, error) {
needsConversion := false
if slices.Contains(IntTypes, A.Dtype()) {
needsConversion = true
}

if !needsConversion {
return tensor.Pow(A, B)
}

oldType, err := DTypeToONNXType(A.Dtype())
if err != nil {
return nil, err
}

newType, err := DTypeToONNXType(tensor.Float64)
if err != nil {
return nil, err
}

A, err = ConvertTensorDtype(A, newType)
if err != nil {
return nil, err
}

B, err = ConvertTensorDtype(B, newType)
if err != nil {
return nil, err
}

out, err := tensor.Pow(A, B)
if err != nil {
return nil, err
}

return ConvertTensorDtype(out, oldType)
}

// Sub subtracts 1 tensor from the other.
func Sub(A, B tensor.Tensor) (tensor.Tensor, error) {
return tensor.Sub(A, B)
Expand Down
27 changes: 27 additions & 0 deletions ops/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,33 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) {
return tensor.New(tensor.WithShape(t.Shape()...), tensor.WithBacking(newBacking)), nil
}

func DTypeToONNXType(t tensor.Dtype) (int32, error) {
switch t {
case tensor.Float32:
return int32(onnx.TensorProto_FLOAT), nil
case tensor.Float64:
return int32(onnx.TensorProto_DOUBLE), nil
case tensor.Int8:
return int32(onnx.TensorProto_INT8), nil
case tensor.Int16:
return int32(onnx.TensorProto_INT16), nil
case tensor.Int32:
return int32(onnx.TensorProto_INT32), nil
case tensor.Int64:
return int32(onnx.TensorProto_INT64), nil
case tensor.Uint8:
return int32(onnx.TensorProto_UINT8), nil
case tensor.Uint16:
return int32(onnx.TensorProto_UINT16), nil
case tensor.Uint32:
return int32(onnx.TensorProto_UINT32), nil
case tensor.Uint64:
return int32(onnx.TensorProto_UINT64), nil
default:
return 0, ErrUnknownTensorONNXDtype(t)
}
}

func convertBacking[B Number](backing []B, dataType int32) (any, error) {
switch onnx.TensorProto_DataType(dataType) {
case onnx.TensorProto_FLOAT:
Expand Down
4 changes: 4 additions & 0 deletions ops/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ func ErrConversionNotSupported(dType int32) error {
return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType)
}

func ErrUnknownTensorONNXDtype(dType tensor.Dtype) error {
return fmt.Errorf("%w: tensor with dtype %v does not have a corresponding onnx type", ErrCast, dType)
}

var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented")

func ErrActivationNotImplemented(activation string) error {
Expand Down
63 changes: 63 additions & 0 deletions ops/pow/pow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package pow

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var pow7TypeConstraints = [][]tensor.Dtype{
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
}

var powTypeConstraints = [][]tensor.Dtype{
{tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
{tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}

// Pow represents the ONNX pow operator.
type Pow struct {
ops.BaseOperator
}

// newPow creates a new pow operator.
func newPow(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Pow{
BaseOperator: ops.NewBaseOperator(
version,
2,
2,
typeConstraints,
"pow",
),
}
}

// Init initializes the pow operator.
func (a *Pow) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the pow operator.
func (a *Pow) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
powTensor := inputs[1]
if inputs[0].Dtype() != powTensor.Dtype() {
to, err := ops.DTypeToONNXType(inputs[0].Dtype())
if err != nil {
return nil, err
}

powTensor, err = ops.ConvertTensorDtype(powTensor, to)
if err != nil {
return nil, err
}
}

return ops.ApplyBinaryOperation(
inputs[0],
powTensor,
ops.Pow,
ops.MultidirectionalBroadcasting,
)
}
68 changes: 68 additions & 0 deletions ops/pow/pow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package pow

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestPowInit(t *testing.T) {
p := &Pow{}
err := p.Init(nil)
assert.Nil(t, err)
}

func TestPow(t *testing.T) {
tests := []struct {
version int64
backing0 any
backing1 any
shapes [][]int
expected any
}{
{
13,
[]float32{0, 1, 2, 3},
[]float32{1, 1, 1, 1},
[][]int{{2, 2}, {2, 2}},
[]float32{0, 1, 2, 3},
},
{
13,
[]float32{0, 1, 2, 3, 4, 5},
[]float32{2, 2, 2, 2, 2, 2},
[][]int{{3, 2}, {3, 2}},
[]float32{0, 1, 4, 9, 16, 25},
},
{
13,
[]float32{0, 1},
[]float32{0, 1, 2, 3},
[][]int{{2}, {2, 2}},
[]float32{1, 1, 0, 1},
},
{
13,
[]int32{1, 2, 3},
[]int32{4, 5, 6},
[][]int{{3}, {3}},
[]int32{1, 32, 729},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing0, test.shapes[0]...),
ops.TensorWithBackingFixture(test.backing1, test.shapes[1]...),
}

pow := powVersions[test.version]()

res, err := pow.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expected, res[0].Data())
}
}
15 changes: 15 additions & 0 deletions ops/pow/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package pow

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var powVersions = ops.OperatorVersions{
7: ops.NewOperatorConstructor(newPow, 7, pow7TypeConstraints),
12: ops.NewOperatorConstructor(newPow, 12, powTypeConstraints),
13: ops.NewOperatorConstructor(newPow, 13, powTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return powVersions
}
6 changes: 6 additions & 0 deletions ops/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ var AllTypes = []tensor.Dtype{
tensor.Bool,
}

// IntTypes is a list with all integer types.
var IntTypes = []tensor.Dtype{
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
}

// NumericTypes is a list with all numeric types.
var NumericTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
Expand Down
14 changes: 13 additions & 1 deletion ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestOps(t *testing.T) {
if expectedTensor.Dtype() == tensor.Bool {
assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data())
} else {
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001)
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.001)
}
}
})
Expand Down Expand Up @@ -472,6 +472,18 @@ var expectedTests = []string{
"test_or_bcast4v2d",
"test_or_bcast4v3d",
"test_or_bcast4v4d",
"test_pow",
"test_pow_bcast_array",
"test_pow_bcast_scalar",
"test_pow_example",
"test_pow_types_float32_int32",
"test_pow_types_float32_int64",
"test_pow_types_float32_uint32",
"test_pow_types_float32_uint64",
"test_pow_types_int32_float32",
"test_pow_types_int32_int32",
"test_pow_types_int64_float32",
"test_pow_types_int64_int64",
"test_prelu_broadcast",
"test_prelu_example",
"test_relu",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/mul"
"github.com/advancedclimatesystems/gonnx/ops/not"
"github.com/advancedclimatesystems/gonnx/ops/or"
"github.com/advancedclimatesystems/gonnx/ops/pow"
"github.com/advancedclimatesystems/gonnx/ops/prelu"
"github.com/advancedclimatesystems/gonnx/ops/reducemax"
"github.com/advancedclimatesystems/gonnx/ops/reducemin"
Expand Down Expand Up @@ -106,6 +107,7 @@ var operators = map[string]ops.OperatorVersions{
"Mul": mul.GetMulVersions(),
"Not": not.GetNotVersions(),
"Or": or.GetOrVersions(),
"Pow": pow.GetVersions(),
"PRelu": prelu.GetPReluVersions(),
"ReduceMax": reducemax.GetReduceMaxVersions(),
"ReduceMin": reducemin.GetReduceMinVersions(),
Expand Down
Loading