Skip to content

Commit

Permalink
Fix stateful expressions in UDFs (#5093)
Browse files Browse the repository at this point in the history
This commit fixes an issue with using aggregation expressions in user-defined
functions where there wasn't a separate state per textual invocation.

Closes #5092
  • Loading branch information
mattnibs authored Jun 19, 2024
1 parent d217ed8 commit 9c3f5fd
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 37 deletions.
25 changes: 23 additions & 2 deletions compiler/kernel/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,13 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) {
var path field.Path
// First check if call is to a user defined function, otherwise check for
// builtin function.
fn, ok := b.funcs[call.Name]
if !ok {
var fn expr.Function
if e, ok := b.udfs[call.Name]; ok {
var err error
if fn, err = b.compileUDFCall(call.Name, e); err != nil {
return nil, err
}
} else {
var err error
fn, path, err = function.New(b.zctx(), call.Name, len(call.Args))
if err != nil {
Expand All @@ -335,6 +340,22 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) {
return expr.NewCall(b.zctx(), fn, exprs), nil
}

// compileUDFCall compiles body, the expression of the user-defined function
// called name, into a new expr.UDF and returns it.
//
// While body is being compiled the new UDF is registered in b.compiledUDFs
// under name. A call to name that is reached from inside body (a recursive
// call) therefore gets the in-progress UDF back instead of compiling body
// again, which would otherwise recurse until the stack overflows. The
// registration is removed when this function returns, so a later call to name
// from outside body is compiled into a UDF of its own.
//
// It returns a nil function and the error from b.compileExpr if body fails to
// compile.
func (b *Builder) compileUDFCall(name string, body dag.Expr) (expr.Function, error) {
	if fn, ok := b.compiledUDFs[name]; ok {
		return fn, nil
	}
	fn := &expr.UDF{}
	b.compiledUDFs[name] = fn
	// Deferred so that the entry is removed on the error path as well;
	// otherwise a UDF whose body failed to compile would stay registered.
	defer delete(b.compiledUDFs, name)
	var err error
	if fn.Body, err = b.compileExpr(body); err != nil {
		return nil, err
	}
	return fn, nil
}

func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) {
e, err := b.compileExpr(a.Expr)
if err != nil {
Expand Down
57 changes: 22 additions & 35 deletions compiler/kernel/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"maps"
"slices"
"strings"
"sync"
Expand Down Expand Up @@ -50,15 +51,16 @@ import (
var ErrJoinParents = errors.New("join requires two upstream parallel query paths")

type Builder struct {
rctx *runtime.Context
mctx *zed.Context
source *data.Source
readers []zio.Reader
progress *zbuf.Progress
arena *zed.Arena // For zed.Values created during compilation.
deletes *sync.Map
funcs map[string]expr.Function
resetters expr.Resetters
rctx *runtime.Context
mctx *zed.Context
source *data.Source
readers []zio.Reader
progress *zbuf.Progress
arena *zed.Arena // For zed.Values created during compilation.
deletes *sync.Map
udfs map[string]dag.Expr
compiledUDFs map[string]*expr.UDF
resetters expr.Resetters
}

func NewBuilder(rctx *runtime.Context, source *data.Source) *Builder {
Expand All @@ -74,8 +76,9 @@ func NewBuilder(rctx *runtime.Context, source *data.Source) *Builder {
RecordsRead: 0,
RecordsMatched: 0,
},
arena: arena,
funcs: make(map[string]expr.Function),
arena: arena,
udfs: make(map[string]dag.Expr),
compiledUDFs: make(map[string]*expr.UDF),
}
}

Expand Down Expand Up @@ -459,11 +462,14 @@ func (b *Builder) compileSeq(seq dag.Seq, parents []zbuf.Puller) ([]zbuf.Puller,
}

func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf.Puller, error) {
// XXX We need to fix how udfs are compiled since there is currently a bug
// where aggregation expressions in udfs do not have separate state per
// invocation.
if err := b.compileFuncs(scope.Funcs); err != nil {
return nil, err
// Because there can be name collisions between a child and parent scope
// we clone the current udf map, populate the cloned map, then restore the
// old scope once the current scope has been built.
parentUDFs := b.udfs
b.udfs = maps.Clone(parentUDFs)
defer func() { b.udfs = parentUDFs }()
for _, f := range scope.Funcs {
b.udfs[f.Name] = f.Expr
}
return b.compileSeq(scope.Body, parents)
}
Expand Down Expand Up @@ -510,25 +516,6 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu
return ops, nil
}

// compileFuncs compiles the user-defined functions in fns and records each
// one in b.funcs under its name.
//
// The work is done in two passes. The first pass stores an empty expr.UDF in
// b.funcs for every function, so that any lookup in b.funcs made while a body
// is being compiled already finds every function in fns. The second pass
// compiles each function's expression and stores it as that UDF's Body.
//
// It returns an error if a name in fns is already present in b.funcs, or the
// first error returned by b.compileExpr.
func (b *Builder) compileFuncs(fns []*dag.Func) error {
	placeholders := make([]*expr.UDF, len(fns))
	for i, decl := range fns {
		if _, dup := b.funcs[decl.Name]; dup {
			return fmt.Errorf("internal error: func %q declared twice", decl.Name)
		}
		placeholders[i] = &expr.UDF{}
		b.funcs[decl.Name] = placeholders[i]
	}
	for i, decl := range fns {
		body, err := b.compileExpr(decl.Expr)
		// The body is stored before the error check, as a single
		// multi-value assignment would do.
		placeholders[i].Body = body
		if err != nil {
			return err
		}
	}
	return nil
}

func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbuf.Puller, error) {
parent := parents[0]
if len(parents) > 1 {
Expand Down
9 changes: 9 additions & 0 deletions runtime/sam/expr/ztests/udf-stateful-expr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
zed: |
func c1(): ( count() )
func c2(): ( c1()+c1() )
yield [c2(),c2(),c2()]
input: 'null'

output: |
[2(uint64),2(uint64),2(uint64)]

0 comments on commit 9c3f5fd

Please sign in to comment.