From ef76d4d2f66e31bf5d668bd71b14e7279befd820 Mon Sep 17 00:00:00 2001 From: Noah Treuhaft Date: Wed, 18 Dec 2024 13:16:57 -0500 Subject: [PATCH] Implement vector switch operator (#5538) --- compiler/kernel/vop.go | 62 +++++++++++- runtime/vam/op/exprswitch.go | 59 ++++++++++++ runtime/vam/op/router.go | 2 +- runtime/vam/op/swtich.go | 94 +++++++++++++++++++ .../ztests => ztests/op}/switch-chained.yaml | 2 + .../ztests => ztests/op}/switch-default.yaml | 2 + .../ztests => ztests/op}/switch-error.yaml | 4 + .../op/switch-expr-default.yaml} | 2 + .../op/switch-expr-done.yaml} | 2 + .../op/switch-expr-over.yaml} | 2 + .../op/switch-expr.yaml} | 2 + .../ztests => ztests/op}/switch-over.yaml | 2 + .../switcher/ztests => ztests/op}/switch.yaml | 2 + 13 files changed, 232 insertions(+), 5 deletions(-) create mode 100644 runtime/vam/op/exprswitch.go create mode 100644 runtime/vam/op/swtich.go rename runtime/{sam/op/switcher/ztests => ztests/op}/switch-chained.yaml (94%) rename runtime/{sam/op/switcher/ztests => ztests/op}/switch-default.yaml (95%) rename runtime/{sam/op/switcher/ztests => ztests/op}/switch-error.yaml (73%) rename runtime/{sam/op/exprswitch/ztests/switch-default.yaml => ztests/op/switch-expr-default.yaml} (95%) rename runtime/{sam/op/exprswitch/ztests/switch-done.yaml => ztests/op/switch-expr-done.yaml} (95%) rename runtime/{sam/op/exprswitch/ztests/switch-over.yaml => ztests/op/switch-expr-over.yaml} (94%) rename runtime/{sam/op/exprswitch/ztests/switch.yaml => ztests/op/switch-expr.yaml} (95%) rename runtime/{sam/op/switcher/ztests => ztests/op}/switch-over.yaml (94%) rename runtime/{sam/op/switcher/ztests => ztests/op}/switch.yaml (96%) diff --git a/compiler/kernel/vop.go b/compiler/kernel/vop.go index cb7cb4ef0b..f6add58800 100644 --- a/compiler/kernel/vop.go +++ b/compiler/kernel/vop.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" + "github.com/brimdata/super" "github.com/brimdata/super/compiler/dag" "github.com/brimdata/super/compiler/optimizer" "github.com/brimdata/super/pkg/field" @@ -39,10 +40,10 @@ func (b *Builder) compileVam(o dag.Op, parents []vector.Puller) ([]vector.Puller case *dag.Scope: //return b.compileVecScope(o, parents) case *dag.Switch: - //if o.Expr != nil { - // return b.compileVamExprSwitch(o, parents) - //} - //return b.compileVecSwitch(o, parents) + if o.Expr != nil { + return b.compileVamExprSwitch(o, parents) + } + return b.compileVamSwitch(o, parents) default: var parent vector.Puller if len(parents) == 1 { @@ -114,6 +115,59 @@ func (b *Builder) compileVamScatter(scatter *dag.Scatter, parents []vector.Pulle return ops, nil } +func (b *Builder) compileVamExprSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) { + parent := parents[0] + if len(parents) > 1 { + parent = vamop.NewCombine(b.rctx, parents) + } + e, err := b.compileVamExpr(swtch.Expr) + if err != nil { + return nil, err + } + s := vamop.NewExprSwitch(b.rctx, parent, e) + var exits []vector.Puller + for _, c := range swtch.Cases { + var val *super.Value + if c.Expr != nil { + val2, err := b.evalAtCompileTime(c.Expr) + if err != nil { + return nil, err + } + if val2.IsError() { + return nil, errors.New("switch case is not a constant expression") + } + val = &val2 + } + parents, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(val)}) + if err != nil { + return nil, err + } + exits = append(exits, parents...) + } + return exits, nil +} + +func (b *Builder) compileVamSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) { + parent := parents[0] + if len(parents) > 1 { + parent = vamop.NewCombine(b.rctx, parents) + } + s := vamop.NewSwitch(b.rctx, parent) + var exits []vector.Puller + for _, c := range swtch.Cases { + e, err := b.compileVamExpr(c.Expr) + if err != nil { + return nil, fmt.Errorf("compiling switch case filter: %w", err) + } + exit, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(e)}) + if err != nil { + return nil, err + } + exits = append(exits, exit...) + } + return exits, nil +} + func (b *Builder) compileVamLeaf(o dag.Op, parent vector.Puller) (vector.Puller, error) { switch o := o.(type) { case *dag.Cut: diff --git a/runtime/vam/op/exprswitch.go b/runtime/vam/op/exprswitch.go new file mode 100644 index 0000000000..e9e0f3f7da --- /dev/null +++ b/runtime/vam/op/exprswitch.go @@ -0,0 +1,59 @@ +package op + +import ( + "context" + + "github.com/brimdata/super" + "github.com/brimdata/super/runtime/vam/expr" + "github.com/brimdata/super/vector" + "github.com/brimdata/super/zcode" +) + +type ExprSwitch struct { + expr expr.Evaluator + router *router + + builder zcode.Builder + cases map[string]*route + caseIndexes map[*route][]uint32 + defaultRoute *route +} + +func NewExprSwitch(ctx context.Context, parent vector.Puller, e expr.Evaluator) *ExprSwitch { + s := &ExprSwitch{expr: e, cases: map[string]*route{}, caseIndexes: map[*route][]uint32{}} + s.router = newRouter(ctx, s, parent) + return s +} + +func (s *ExprSwitch) AddCase(val *super.Value) vector.Puller { + r := s.router.addRoute() + if val == nil { + s.defaultRoute = r + } else { + s.cases[string(val.Bytes())] = r + } + return r +} + +func (s *ExprSwitch) forward(vec vector.Any) bool { + defer clear(s.caseIndexes) + exprVec := s.expr.Eval(vec) + for i := range exprVec.Len() { + s.builder.Truncate() + exprVec.Serialize(&s.builder, i) + route, ok := s.cases[string(s.builder.Bytes().Body())] + if !ok { + route = s.defaultRoute + } + if route != nil { + s.caseIndexes[route] = append(s.caseIndexes[route], i) + } + } + for route, index := range s.caseIndexes { + view := vector.NewView(vec, index) + if !route.send(view, nil) { + return false + } + } + return true +} diff --git a/runtime/vam/op/router.go b/runtime/vam/op/router.go index c7c4e6f200..d45d9f254a 100644 --- a/runtime/vam/op/router.go +++ b/runtime/vam/op/router.go @@ -25,7 +25,7 @@ func newRouter(ctx context.Context, f forwarder, parent vector.Puller) *router { return &router{ctx: ctx, forwarder: f, parent: parent} } -func (r *router) addRoute() vector.Puller { +func (r *router) addRoute() *route { route := &route{r, make(chan result), make(chan struct{}), false} r.routes = append(r.routes, route) return route diff --git a/runtime/vam/op/swtich.go b/runtime/vam/op/swtich.go new file mode 100644 index 0000000000..67fda9c23b --- /dev/null +++ b/runtime/vam/op/swtich.go @@ -0,0 +1,94 @@ +package op + +import ( + "context" + + "github.com/RoaringBitmap/roaring" + "github.com/brimdata/super" + "github.com/brimdata/super/runtime/vam/expr" + "github.com/brimdata/super/vector" +) + +type Switch struct { + router *router + cases []expr.Evaluator +} + +func NewSwitch(ctx context.Context, parent vector.Puller) *Switch { + s := &Switch{} + s.router = newRouter(ctx, s, parent) + return s +} + +func (s *Switch) AddCase(e expr.Evaluator) vector.Puller { + s.cases = append(s.cases, e) + return s.router.addRoute() +} + +func (s *Switch) forward(vec vector.Any) bool { + doneMap := roaring.New() + for i, c := range s.cases { + maskVec := c.Eval(vec) + boolMap, errMap := expr.BoolMask(maskVec) + boolMap.AndNot(doneMap) + errMap.AndNot(doneMap) + doneMap.Or(boolMap) + if !errMap.IsEmpty() { + // Clone because iteration results are undefined if the bitmap is modified. + for it := errMap.Clone().Iterator(); it.HasNext(); { + i := it.Next() + if isErrorMissing(maskVec, i) { + errMap.Remove(i) + } + } + } + var vec2 vector.Any + if errMap.IsEmpty() { + if boolMap.IsEmpty() { + continue + } + vec2 = vector.NewView(vec, boolMap.ToArray()) + } else if boolMap.IsEmpty() { + vec2 = vector.NewView(maskVec, errMap.ToArray()) + } else { + valIndex := boolMap.ToArray() + errIndex := errMap.ToArray() + tags := make([]uint32, 0, len(valIndex)+len(errIndex)) + for len(valIndex) > 0 && len(errIndex) > 0 { + if valIndex[0] < errIndex[0] { + valIndex = valIndex[1:] + tags = append(tags, 0) + } else { + errIndex = errIndex[1:] + tags = append(tags, 1) + } + } + tags = append(tags, valIndex...) + tags = append(tags, errIndex...) + valVec := vector.NewView(vec, valIndex) + errVec := vector.NewView(maskVec, errIndex) + vec2 = vector.NewDynamic(tags, []vector.Any{valVec, errVec}) + } + if !s.router.routes[i].send(vec2, nil) { + return false + } + } + return true +} + +func isErrorMissing(vec vector.Any, i uint32) bool { + vec = vector.Under(vec) + if dynVec, ok := vec.(*vector.Dynamic); ok { + vec = dynVec.Values[dynVec.Tags[i]] + i = dynVec.TagMap.Forward[i] + } + errVec, ok := vec.(*vector.Error) + if !ok { + return false + } + if errVec.Vals.Type().ID() != super.IDString { + return false + } + s, null := vector.StringValue(errVec.Vals, i) + return !null && s == string(super.Missing) +} diff --git a/runtime/sam/op/switcher/ztests/switch-chained.yaml b/runtime/ztests/op/switch-chained.yaml similarity index 94% rename from runtime/sam/op/switcher/ztests/switch-chained.yaml rename to runtime/ztests/op/switch-chained.yaml index 6e9c2ed883..1d553291e0 100644 --- a/runtime/sam/op/switcher/ztests/switch-chained.yaml +++ b/runtime/ztests/op/switch-chained.yaml @@ -12,6 +12,8 @@ zed: | case this==3 => yield 4 ) +vector: true + input: | 1 diff --git a/runtime/sam/op/switcher/ztests/switch-default.yaml b/runtime/ztests/op/switch-default.yaml similarity index 95% rename from runtime/sam/op/switcher/ztests/switch-default.yaml rename to runtime/ztests/op/switch-default.yaml index 4eb8ddb4e5..bee3578365 100644 --- a/runtime/sam/op/switcher/ztests/switch-default.yaml +++ b/runtime/ztests/op/switch-default.yaml @@ -6,6 +6,8 @@ zed: | default => count:=count() |> put a:=-1 ) |> sort a +vector: true + input: | {a:1,s:"a"} {a:2,s:"B"} diff --git a/runtime/sam/op/switcher/ztests/switch-error.yaml b/runtime/ztests/op/switch-error.yaml similarity index 73% rename from runtime/sam/op/switcher/ztests/switch-error.yaml rename to runtime/ztests/op/switch-error.yaml index 0dabdedf34..35299b2b00 100644 --- a/runtime/sam/op/switcher/ztests/switch-error.yaml +++ b/runtime/ztests/op/switch-error.yaml @@ -2,8 +2,11 @@ zed: | switch ( case a == 1 => put v:='one' case a / 0 => put v:='xxx' + case a % 0 => put v:='yyy' ) |> sort this +vector: true + input: | {a:1,s:"a"} {a:2,s:"b"} @@ -11,3 +14,4 @@ input: | output: | {a:1,s:"a",v:"one"} error("divide by zero") + error("divide by zero") diff --git a/runtime/sam/op/exprswitch/ztests/switch-default.yaml b/runtime/ztests/op/switch-expr-default.yaml similarity index 95% rename from runtime/sam/op/exprswitch/ztests/switch-default.yaml rename to runtime/ztests/op/switch-expr-default.yaml index 057797f8a0..a299e9b605 100644 --- a/runtime/sam/op/exprswitch/ztests/switch-default.yaml +++ b/runtime/ztests/op/switch-expr-default.yaml @@ -6,6 +6,8 @@ zed: | default => count:=count() |> put a:=-1 ) |> sort a +vector: true + input: | {a:1,s:"a"} {a:2,s:"B"} diff --git a/runtime/sam/op/exprswitch/ztests/switch-done.yaml b/runtime/ztests/op/switch-expr-done.yaml similarity index 95% rename from runtime/sam/op/exprswitch/ztests/switch-done.yaml rename to runtime/ztests/op/switch-expr-done.yaml index 410d65871d..a3ec18a1c3 100644 --- a/runtime/sam/op/exprswitch/ztests/switch-done.yaml +++ b/runtime/ztests/op/switch-expr-done.yaml @@ -5,6 +5,8 @@ zed: | default => pass ) |> sort b +vector: true + input: | {a:1,b:1} {a:2,b:2} diff --git a/runtime/sam/op/exprswitch/ztests/switch-over.yaml b/runtime/ztests/op/switch-expr-over.yaml similarity index 94% rename from runtime/sam/op/exprswitch/ztests/switch-over.yaml rename to runtime/ztests/op/switch-expr-over.yaml index f7464a92c2..12bdd39e6f 100644 --- a/runtime/sam/op/exprswitch/ztests/switch-over.yaml +++ b/runtime/ztests/op/switch-expr-over.yaml @@ -4,6 +4,8 @@ zed: | default => over a |> yield {b:this} ) |> sort this +vector: true + input: | {a:[1,2,3]} {a:[6,7,8,9]} diff --git a/runtime/sam/op/exprswitch/ztests/switch.yaml b/runtime/ztests/op/switch-expr.yaml similarity index 95% rename from runtime/sam/op/exprswitch/ztests/switch.yaml rename to runtime/ztests/op/switch-expr.yaml index aeca101369..5d3f7b8038 100644 --- a/runtime/sam/op/exprswitch/ztests/switch.yaml +++ b/runtime/ztests/op/switch-expr.yaml @@ -5,6 +5,8 @@ zed: | case 3 => ? null ) |> sort a +vector: true + input: | {a:1(int32),s:"a"} {a:2(int32),s:"B"} diff --git a/runtime/sam/op/switcher/ztests/switch-over.yaml b/runtime/ztests/op/switch-over.yaml similarity index 94% rename from runtime/sam/op/switcher/ztests/switch-over.yaml rename to runtime/ztests/op/switch-over.yaml index 08db1ee0f5..15e5e60dae 100644 --- a/runtime/sam/op/switcher/ztests/switch-over.yaml +++ b/runtime/ztests/op/switch-over.yaml @@ -4,6 +4,8 @@ zed: | default => over a |> yield {b:this} ) |> sort this +vector: true + input: | {a:[1,2,3]} {a:[6,7,8,9]} diff --git a/runtime/sam/op/switcher/ztests/switch.yaml b/runtime/ztests/op/switch.yaml similarity index 96% rename from runtime/sam/op/switcher/ztests/switch.yaml rename to runtime/ztests/op/switch.yaml index 835448ca87..99aec77a53 100644 --- a/runtime/sam/op/switcher/ztests/switch.yaml +++ b/runtime/ztests/op/switch.yaml @@ -6,6 +6,8 @@ zed: | case true => count:=count() |> put a:=-1 ) |> sort a +vector: true + input: | {a:1(int32),s:"a"} {a:2(int32),s:"B"}