diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 0f5659f8af..aa7c50b03b 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -315,8 +315,13 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { var path field.Path // First check if call is to a user defined function, otherwise check for // builtin function. - fn, ok := b.funcs[call.Name] - if !ok { + var fn expr.Function + if e, ok := b.udfs[call.Name]; ok { + var err error + if fn, err = b.compileUDFCall(call.Name, e); err != nil { + return nil, err + } + } else { var err error fn, path, err = function.New(b.zctx(), call.Name, len(call.Args)) if err != nil { @@ -335,6 +340,22 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { return expr.NewCall(b.zctx(), fn, exprs), nil } +func (b *Builder) compileUDFCall(name string, body dag.Expr) (expr.Function, error) { + if fn, ok := b.compiledUDFs[name]; ok { + return fn, nil + } + fn := &expr.UDF{} + // We store compiled UDF calls here so as to avoid stack overflows on + // recursive calls. + b.compiledUDFs[name] = fn + var err error + if fn.Body, err = b.compileExpr(body); err != nil { + return nil, err + } + delete(b.compiledUDFs, name) + return fn, nil +} + func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) { e, err := b.compileExpr(a.Expr) if err != nil { diff --git a/compiler/kernel/op.go b/compiler/kernel/op.go index 018210b503..313f7c3419 100644 --- a/compiler/kernel/op.go +++ b/compiler/kernel/op.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "slices" "strings" "sync" @@ -50,15 +51,16 @@ import ( var ErrJoinParents = errors.New("join requires two upstream parallel query paths") type Builder struct { - rctx *runtime.Context - mctx *zed.Context - source *data.Source - readers []zio.Reader - progress *zbuf.Progress - arena *zed.Arena // For zed.Values created during compilation. - deletes *sync.Map - funcs map[string]expr.Function - resetters expr.Resetters + rctx *runtime.Context + mctx *zed.Context + source *data.Source + readers []zio.Reader + progress *zbuf.Progress + arena *zed.Arena // For zed.Values created during compilation. + deletes *sync.Map + udfs map[string]dag.Expr + compiledUDFs map[string]*expr.UDF + resetters expr.Resetters } func NewBuilder(rctx *runtime.Context, source *data.Source) *Builder { @@ -74,8 +76,9 @@ func NewBuilder(rctx *runtime.Context, source *data.Source) *Builder { RecordsRead: 0, RecordsMatched: 0, }, - arena: arena, - funcs: make(map[string]expr.Function), + arena: arena, + udfs: make(map[string]dag.Expr), + compiledUDFs: make(map[string]*expr.UDF), } } @@ -459,11 +462,14 @@ func (b *Builder) compileSeq(seq dag.Seq, parents []zbuf.Puller) ([]zbuf.Puller, } func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf.Puller, error) { - // XXX We need to fix how udfs are compiled since there is currently a bug - // where aggregation expressions in udfs do not have separate state per - // invocation. - if err := b.compileFuncs(scope.Funcs); err != nil { - return nil, err + // Because there can be name collisions between a child and parent scope + // we clone the current udf map, populate the cloned map, then restore the + // old scope once the current scope has been built. + parentUDFs := b.udfs + b.udfs = maps.Clone(parentUDFs) + defer func() { b.udfs = parentUDFs }() + for _, f := range scope.Funcs { + b.udfs[f.Name] = f.Expr } return b.compileSeq(scope.Body, parents) } @@ -510,25 +516,6 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu return ops, nil } -func (b *Builder) compileFuncs(fns []*dag.Func) error { - udfs := make([]*expr.UDF, 0, len(fns)) - for _, f := range fns { - if _, ok := b.funcs[f.Name]; ok { - return fmt.Errorf("internal error: func %q declared twice", f.Name) - } - u := &expr.UDF{} - b.funcs[f.Name] = u - udfs = append(udfs, u) - } - for i := range fns { - var err error - if udfs[i].Body, err = b.compileExpr(fns[i].Expr); err != nil { - return err - } - } - return nil -} - func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbuf.Puller, error) { parent := parents[0] if len(parents) > 1 { diff --git a/runtime/sam/expr/ztests/udf-stateful-expr.yaml b/runtime/sam/expr/ztests/udf-stateful-expr.yaml new file mode 100644 index 0000000000..b10cdd318b --- /dev/null +++ b/runtime/sam/expr/ztests/udf-stateful-expr.yaml @@ -0,0 +1,9 @@ +zed: | + func c1(): ( count() ) + func c2(): ( c1()+c1() ) + yield [c2(),c2(),c2()] + +input: 'null' + +output: | + [2(uint64),2(uint64),2(uint64)]