Skip to content

Commit

Permalink
Added Where operator (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 authored Dec 22, 2024
1 parent c33e5d2 commit 3ea18cc
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ops/where/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package where

import "github.com/advancedclimatesystems/gonnx/ops"

var whereVersions = ops.OperatorVersions{
9: ops.NewOperatorConstructor(newWhere, 9, whereTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return whereVersions
}
105 changes: 105 additions & 0 deletions ops/where/where.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package where

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var whereTypeConstraints = [][]tensor.Dtype{
{tensor.Bool},
ops.AllTypes,
ops.AllTypes,
}

// Where represents the ONNX where operator.
type Where struct {
ops.BaseOperator
}

// newWhere creates a new where operator.
func newWhere(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Where{
BaseOperator: ops.NewBaseOperator(
version,
3,
3,
typeConstraints,
"where",
),
}
}

// Init initializes the where operator.
func (w *Where) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the where operator.
func (w *Where) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
condition := inputs[0]

X := inputs[1]
Y := inputs[2]

X, Y, err := ops.MultidirectionalBroadcast(X, Y)
if err != nil {
return nil, err
}

condition, X, err = ops.MultidirectionalBroadcast(condition, X)
if err != nil {
return nil, err
}

out, err := where(X, Y, condition)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, err
}

func where(X, Y, condition tensor.Tensor) (tensor.Tensor, error) {
out := tensor.New(tensor.Of(X.Dtype()), tensor.WithShape(X.Shape()...))

iterator := condition.Iterator()
iterator.Reset()

for !iterator.Done() {
coords := iterator.Coord()

conditionRaw, err := condition.At(coords...)
if err != nil {
return nil, err
}

conditionValue, ok := conditionRaw.(bool)
if !ok {
return nil, ops.ErrCast
}

var value any
if conditionValue {
value, err = X.At(coords...)
} else {
value, err = Y.At(coords...)
}

if err != nil {
return nil, err
}

err = out.SetAt(value, coords...)
if err != nil {
return nil, err
}

_, err = iterator.Next()
if err != nil {
return nil, err
}
}

return out, nil
}
92 changes: 92 additions & 0 deletions ops/where/where_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package where

import (
"testing"

"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestWhereInit(t *testing.T) {
op := whereVersions[9]()
err := op.Init(nil)
assert.Nil(t, err)
}

func TestWhere(t *testing.T) {
tests := []struct {
version int64
condition []bool
conditionShape []int
backing1 []float32
backing1Shape []int
backing2 []float32
backing2Shape []int
expectedBacking []float32
}{
{
9,
[]bool{true, false, true},
[]int{3},
[]float32{1, 2, 3},
[]int{3},
[]float32{4, 5, 6},
[]int{3},
[]float32{1, 5, 3},
},
{
9,
[]bool{true, false, true, false},
[]int{2, 2},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{4, 5},
[]int{1, 2},
[]float32{1, 5, 3, 5},
},
{
9,
[]bool{false, true},
[]int{2},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{4, 5},
[]int{1, 2},
[]float32{4, 2, 4, 4},
},
{
9,
[]bool{false, false, false, true, true, true},
[]int{2, 3},
[]float32{1, 2, 3, 4, 5, 6},
[]int{2, 3},
[]float32{4, 5, 6},
[]int{3},
[]float32{4, 5, 6, 4, 5, 6},
},
{
9,
[]bool{false, true, true, false, false, true},
[]int{2, 3},
[]float32{1, 2, 3, 4, 5, 6},
[]int{2, 3},
[]float32{4, 5, 6},
[]int{3},
[]float32{4, 2, 3, 4, 5, 6},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
tensor.New(tensor.WithShape(test.conditionShape...), tensor.WithBacking(test.condition)),
tensor.New(tensor.WithShape(test.backing1Shape...), tensor.WithBacking(test.backing1)),
tensor.New(tensor.WithShape(test.backing2Shape...), tensor.WithBacking(test.backing2)),
}

op := whereVersions[test.version]()

res, err := op.Apply(inputs)
assert.Nil(t, err)
assert.Equal(t, test.expectedBacking, res[0].Data())
}
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ var expectedTests = []string{
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
"test_unsqueeze_unsorted_axes",
"test_where_example",
"test_where_long_example",
"test_xor_bcast3v1d",
"test_xor_bcast3v2d",
"test_xor_bcast4v2d",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/tanh"
"github.com/advancedclimatesystems/gonnx/ops/transpose"
"github.com/advancedclimatesystems/gonnx/ops/unsqueeze"
"github.com/advancedclimatesystems/gonnx/ops/where"
"github.com/advancedclimatesystems/gonnx/ops/xor"
)

Expand Down Expand Up @@ -132,6 +133,7 @@ var operators = map[string]ops.OperatorVersions{
"Tanh": tanh.GetTanhVersions(),
"Transpose": transpose.GetTransposeVersions(),
"Unsqueeze": unsqueeze.GetUnsqueezeVersions(),
"Where": where.GetVersions(),
"Xor": xor.GetXorVersions(),
}

Expand Down

0 comments on commit 3ea18cc

Please sign in to comment.