Skip to content

Commit

Permalink
Added Pow operator (#222)
Browse files Browse the repository at this point in the history
* Added Pow operator

* Handle non-float pow csaes

* Fix lint

* Update pipeline go versions

* Use go1.23

* Back to go1.21
  • Loading branch information
Swopper050 authored Dec 22, 2024
1 parent 22331fb commit c97d70c
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 6 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Install linter
run: make install_lint
Expand All @@ -34,7 +34,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Install dependencies
run: make install
Expand All @@ -56,7 +56,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Build amd64
run: make build_amd64
Expand All @@ -69,7 +69,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21

- name: Build arm64
run: make build_arm64
2 changes: 1 addition & 1 deletion onnx/graph_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) {
// ReadUint64ArrayFromBytes reads data and parses it to an array of uint64.
func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) {
buffer := bytes.NewReader(data)
element := make([]byte, int32Size)
element := make([]byte, int64Size)

var (
err error
Expand Down
44 changes: 44 additions & 0 deletions ops/binary_op.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ops

import (
"slices"

"gorgonia.org/tensor"
)

Expand Down Expand Up @@ -48,6 +50,48 @@ func Mul(A, B tensor.Tensor) (tensor.Tensor, error) {
return tensor.Mul(A, B)
}

// Pow raises the first tensor to the power of the second tensor.
// Because the gorgonia.Tensor 'Pow' operation only supports float32 and float64,
// we need to convert the tensors to float64 if they are of a different type.
// After the operation is done, we convert the result back to the original type.
func Pow(A, B tensor.Tensor) (tensor.Tensor, error) {
needsConversion := false
if slices.Contains(IntTypes, A.Dtype()) {
needsConversion = true
}

if !needsConversion {
return tensor.Pow(A, B)
}

oldType, err := DTypeToONNXType(A.Dtype())
if err != nil {
return nil, err
}

newType, err := DTypeToONNXType(tensor.Float64)
if err != nil {
return nil, err
}

A, err = ConvertTensorDtype(A, newType)
if err != nil {
return nil, err
}

B, err = ConvertTensorDtype(B, newType)
if err != nil {
return nil, err
}

out, err := tensor.Pow(A, B)
if err != nil {
return nil, err
}

return ConvertTensorDtype(out, oldType)
}

// Sub subtracts 1 tensor from the other.
func Sub(A, B tensor.Tensor) (tensor.Tensor, error) {
return tensor.Sub(A, B)
Expand Down
27 changes: 27 additions & 0 deletions ops/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,33 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) {
return tensor.New(tensor.WithShape(t.Shape()...), tensor.WithBacking(newBacking)), nil
}

func DTypeToONNXType(t tensor.Dtype) (int32, error) {
switch t {
case tensor.Float32:
return int32(onnx.TensorProto_FLOAT), nil
case tensor.Float64:
return int32(onnx.TensorProto_DOUBLE), nil
case tensor.Int8:
return int32(onnx.TensorProto_INT8), nil
case tensor.Int16:
return int32(onnx.TensorProto_INT16), nil
case tensor.Int32:
return int32(onnx.TensorProto_INT32), nil
case tensor.Int64:
return int32(onnx.TensorProto_INT64), nil
case tensor.Uint8:
return int32(onnx.TensorProto_UINT8), nil
case tensor.Uint16:
return int32(onnx.TensorProto_UINT16), nil
case tensor.Uint32:
return int32(onnx.TensorProto_UINT32), nil
case tensor.Uint64:
return int32(onnx.TensorProto_UINT64), nil
default:
return 0, ErrUnknownTensorONNXDtype(t)
}
}

func convertBacking[B Number](backing []B, dataType int32) (any, error) {
switch onnx.TensorProto_DataType(dataType) {
case onnx.TensorProto_FLOAT:
Expand Down
4 changes: 4 additions & 0 deletions ops/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ func ErrConversionNotSupported(dType int32) error {
return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType)
}

func ErrUnknownTensorONNXDtype(dType tensor.Dtype) error {
return fmt.Errorf("%w: tensor with dtype %v does not have a corresponding onnx type", ErrCast, dType)
}

var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented")

func ErrActivationNotImplemented(activation string) error {
Expand Down
63 changes: 63 additions & 0 deletions ops/pow/pow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package pow

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var pow7TypeConstraints = [][]tensor.Dtype{
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
}

var powTypeConstraints = [][]tensor.Dtype{
{tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
{tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}

// Pow represents the ONNX pow operator.
type Pow struct {
ops.BaseOperator
}

// newPow creates a new pow operator.
func newPow(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Pow{
BaseOperator: ops.NewBaseOperator(
version,
2,
2,
typeConstraints,
"pow",
),
}
}

// Init initializes the pow operator.
func (a *Pow) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the pow operator.
func (a *Pow) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
powTensor := inputs[1]
if inputs[0].Dtype() != powTensor.Dtype() {
to, err := ops.DTypeToONNXType(inputs[0].Dtype())
if err != nil {
return nil, err
}

powTensor, err = ops.ConvertTensorDtype(powTensor, to)
if err != nil {
return nil, err
}
}

return ops.ApplyBinaryOperation(
inputs[0],
powTensor,
ops.Pow,
ops.MultidirectionalBroadcasting,
)
}
68 changes: 68 additions & 0 deletions ops/pow/pow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package pow

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestPowInit(t *testing.T) {
p := &Pow{}
err := p.Init(nil)
assert.Nil(t, err)
}

func TestPow(t *testing.T) {
tests := []struct {
version int64
backing0 any
backing1 any
shapes [][]int
expected any
}{
{
13,
[]float32{0, 1, 2, 3},
[]float32{1, 1, 1, 1},
[][]int{{2, 2}, {2, 2}},
[]float32{0, 1, 2, 3},
},
{
13,
[]float32{0, 1, 2, 3, 4, 5},
[]float32{2, 2, 2, 2, 2, 2},
[][]int{{3, 2}, {3, 2}},
[]float32{0, 1, 4, 9, 16, 25},
},
{
13,
[]float32{0, 1},
[]float32{0, 1, 2, 3},
[][]int{{2}, {2, 2}},
[]float32{1, 1, 0, 1},
},
{
13,
[]int32{1, 2, 3},
[]int32{4, 5, 6},
[][]int{{3}, {3}},
[]int32{1, 32, 729},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing0, test.shapes[0]...),
ops.TensorWithBackingFixture(test.backing1, test.shapes[1]...),
}

pow := powVersions[test.version]()

res, err := pow.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expected, res[0].Data())
}
}
15 changes: 15 additions & 0 deletions ops/pow/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package pow

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var powVersions = ops.OperatorVersions{
7: ops.NewOperatorConstructor(newPow, 7, pow7TypeConstraints),
12: ops.NewOperatorConstructor(newPow, 12, powTypeConstraints),
13: ops.NewOperatorConstructor(newPow, 13, powTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return powVersions
}
6 changes: 6 additions & 0 deletions ops/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ var AllTypes = []tensor.Dtype{
tensor.Bool,
}

// IntTypes is a list with all integer types.
var IntTypes = []tensor.Dtype{
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
}

// NumericTypes is a list with all numeric types.
var NumericTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
Expand Down
14 changes: 13 additions & 1 deletion ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func TestOps(t *testing.T) {
if expectedTensor.Dtype() == tensor.Bool {
assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data())
} else {
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001)
assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.001)
}
}
})
Expand Down Expand Up @@ -480,6 +480,18 @@ var expectedTests = []string{
"test_or_bcast4v2d",
"test_or_bcast4v3d",
"test_or_bcast4v4d",
"test_pow",
"test_pow_bcast_array",
"test_pow_bcast_scalar",
"test_pow_example",
"test_pow_types_float32_int32",
"test_pow_types_float32_int64",
"test_pow_types_float32_uint32",
"test_pow_types_float32_uint64",
"test_pow_types_int32_float32",
"test_pow_types_int32_int32",
"test_pow_types_int64_float32",
"test_pow_types_int64_int64",
"test_prelu_broadcast",
"test_prelu_example",
"test_relu",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/mul"
"github.com/advancedclimatesystems/gonnx/ops/not"
"github.com/advancedclimatesystems/gonnx/ops/or"
"github.com/advancedclimatesystems/gonnx/ops/pow"
"github.com/advancedclimatesystems/gonnx/ops/prelu"
"github.com/advancedclimatesystems/gonnx/ops/reducemax"
"github.com/advancedclimatesystems/gonnx/ops/reducemean"
Expand Down Expand Up @@ -107,6 +108,7 @@ var operators = map[string]ops.OperatorVersions{
"Mul": mul.GetMulVersions(),
"Not": not.GetNotVersions(),
"Or": or.GetOrVersions(),
"Pow": pow.GetVersions(),
"PRelu": prelu.GetPReluVersions(),
"ReduceMax": reducemax.GetReduceMaxVersions(),
"ReduceMean": reducemean.GetVersions(),
Expand Down

0 comments on commit c97d70c

Please sign in to comment.