diff --git a/ops/where/versions.go b/ops/where/versions.go new file mode 100644 index 0000000..5ce1463 --- /dev/null +++ b/ops/where/versions.go @@ -0,0 +1,11 @@ +package where + +import "github.com/advancedclimatesystems/gonnx/ops" + +var whereVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newWhere, 9, whereTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return whereVersions +} diff --git a/ops/where/where.go b/ops/where/where.go new file mode 100644 index 0000000..1c56739 --- /dev/null +++ b/ops/where/where.go @@ -0,0 +1,105 @@ +package where + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var whereTypeConstraints = [][]tensor.Dtype{ + {tensor.Bool}, + ops.AllTypes, + ops.AllTypes, +} + +// Where represents the ONNX where operator. +type Where struct { + ops.BaseOperator +} + +// newWhere creates a new where operator. +func newWhere(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Where{ + BaseOperator: ops.NewBaseOperator( + version, + 3, + 3, + typeConstraints, + "where", + ), + } +} + +// Init initializes the where operator. +func (w *Where) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the where operator. +func (w *Where) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + condition := inputs[0] + + X := inputs[1] + Y := inputs[2] + + X, Y, err := ops.MultidirectionalBroadcast(X, Y) + if err != nil { + return nil, err + } + + condition, X, err = ops.MultidirectionalBroadcast(condition, X) + if err != nil { + return nil, err + } + + out, err := where(X, Y, condition) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, err +} + +func where(X, Y, condition tensor.Tensor) (tensor.Tensor, error) { + out := tensor.New(tensor.Of(X.Dtype()), tensor.WithShape(X.Shape()...)) + + iterator := condition.Iterator() + iterator.Reset() + + for !iterator.Done() { + coords := iterator.Coord() + + conditionRaw, err := condition.At(coords...) + if err != nil { + return nil, err + } + + conditionValue, ok := conditionRaw.(bool) + if !ok { + return nil, ops.ErrCast + } + + var value any + if conditionValue { + value, err = X.At(coords...) + } else { + value, err = Y.At(coords...) + } + + if err != nil { + return nil, err + } + + err = out.SetAt(value, coords...) + if err != nil { + return nil, err + } + + _, err = iterator.Next() + if err != nil { + return nil, err + } + } + + return out, nil +} diff --git a/ops/where/where_test.go b/ops/where/where_test.go new file mode 100644 index 0000000..25d0d48 --- /dev/null +++ b/ops/where/where_test.go @@ -0,0 +1,92 @@ +package where + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestWhereInit(t *testing.T) { + op := whereVersions[9]() + err := op.Init(nil) + assert.Nil(t, err) +} + +func TestWhere(t *testing.T) { + tests := []struct { + version int64 + condition []bool + conditionShape []int + backing1 []float32 + backing1Shape []int + backing2 []float32 + backing2Shape []int + expectedBacking []float32 + }{ + { + 9, + []bool{true, false, true}, + []int{3}, + []float32{1, 2, 3}, + []int{3}, + []float32{4, 5, 6}, + []int{3}, + []float32{1, 5, 3}, + }, + { + 9, + []bool{true, false, true, false}, + []int{2, 2}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{4, 5}, + []int{1, 2}, + []float32{1, 5, 3, 5}, + }, + { + 9, + []bool{false, true}, + []int{2}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{4, 5}, + []int{1, 2}, + []float32{4, 2, 4, 4}, + }, + { + 9, + []bool{false, false, false, true, true, true}, + []int{2, 3}, + []float32{1, 2, 3, 4, 5, 6}, + []int{2, 3}, + []float32{4, 5, 6}, + []int{3}, + []float32{4, 5, 6, 4, 5, 6}, + }, + { + 9, + []bool{false, true, true, false, false, true}, + []int{2, 3}, + []float32{1, 2, 3, 4, 5, 6}, + []int{2, 3}, + []float32{4, 5, 6}, + []int{3}, + []float32{4, 2, 3, 4, 5, 6}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + tensor.New(tensor.WithShape(test.conditionShape...), tensor.WithBacking(test.condition)), + tensor.New(tensor.WithShape(test.backing1Shape...), tensor.WithBacking(test.backing1)), + tensor.New(tensor.WithShape(test.backing2Shape...), tensor.WithBacking(test.backing2)), + } + + op := whereVersions[test.version]() + + res, err := op.Apply(inputs) + assert.Nil(t, err) + assert.Equal(t, test.expectedBacking, res[0].Data()) + } +} diff --git a/ops_test.go b/ops_test.go index 9409ec1..d065549 100644 --- a/ops_test.go +++ b/ops_test.go @@ -555,6 +555,8 @@ var expectedTests = []string{ "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", "test_unsqueeze_unsorted_axes", + "test_where_example", + "test_where_long_example", "test_xor_bcast3v1d", "test_xor_bcast3v2d", "test_xor_bcast4v2d", diff --git a/opset.go b/opset.go index 64a4415..ce842b1 100644 --- a/opset.go +++ b/opset.go @@ -61,6 +61,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/tanh" "github.com/advancedclimatesystems/gonnx/ops/transpose" "github.com/advancedclimatesystems/gonnx/ops/unsqueeze" + "github.com/advancedclimatesystems/gonnx/ops/where" "github.com/advancedclimatesystems/gonnx/ops/xor" ) @@ -132,6 +133,7 @@ var operators = map[string]ops.OperatorVersions{ "Tanh": tanh.GetTanhVersions(), "Transpose": transpose.GetTransposeVersions(), "Unsqueeze": unsqueeze.GetUnsqueezeVersions(), + "Where": where.GetVersions(), "Xor": xor.GetXorVersions(), }