diff --git a/compiler/ast/dag/expr.go b/compiler/ast/dag/expr.go index a735d4133a..5c0df48cd8 100644 --- a/compiler/ast/dag/expr.go +++ b/compiler/ast/dag/expr.go @@ -62,6 +62,11 @@ type ( Kind string `json:"kind" unpack:""` Value string `json:"value"` } + MapCall struct { + Kind string `json:"kind" unpack:""` + Expr Expr `json:"expr"` + Inner Expr `json:"inner"` + } MapExpr struct { Kind string `json:"kind" unpack:""` Entries []Entry `json:"entries"` @@ -121,6 +126,7 @@ func (*Conditional) ExprDAG() {} func (*Dot) ExprDAG() {} func (*Func) ExprDAG() {} func (*Literal) ExprDAG() {} +func (*MapCall) ExprDAG() {} func (*MapExpr) ExprDAG() {} func (*OverExpr) ExprDAG() {} func (*RecordExpr) ExprDAG() {} diff --git a/compiler/ast/dag/unpack.go b/compiler/ast/dag/unpack.go index 0a03034870..cf781d61c2 100644 --- a/compiler/ast/dag/unpack.go +++ b/compiler/ast/dag/unpack.go @@ -31,6 +31,7 @@ var unpacker = unpack.New( Lister{}, Literal{}, Load{}, + MapCall{}, MapExpr{}, Merge{}, Over{}, diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 3ae2ee0bc8..399911b3b7 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -83,6 +83,8 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { return b.compileArrayExpr(e) case *dag.SetExpr: return b.compileSetExpr(e) + case *dag.MapCall: + return b.compileMapCall(e) case *dag.MapExpr: return b.compileMapExpr(e) case *dag.Agg: @@ -309,6 +311,18 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { return expr.NewCall(b.zctx(), fn, exprs), nil } +func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) { + e, err := b.compileExpr(a.Expr) + if err != nil { + return nil, err + } + inner, err := b.compileExpr(a.Inner) + if err != nil { + return nil, err + } + return expr.NewMapCall(b.zctx(), e, inner), nil +} + func (b *Builder) compileShaper(node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { args := node.Args field, err := b.compileExpr(args[0]) diff --git a/compiler/semantic/expr.go b/compiler/semantic/expr.go index bf705ce23e..89546c0bce 100644 --- a/compiler/semantic/expr.go +++ b/compiler/semantic/expr.go @@ -506,6 +506,28 @@ func (a *analyzer) semCall(call *ast.Call) (dag.Expr, error) { if nargs == 1 { exprs = append([]dag.Expr{&dag.This{Kind: "This"}}, exprs...) } + case name == "map": + if err := function.CheckArgCount(nargs, 2, 2); err != nil { + return nil, fmt.Errorf("%s(): %w", name, err) + } + id, ok := call.Args[1].(*ast.ID) + if !ok { + return nil, fmt.Errorf("%s(): second argument must be the identifier of a function", name) + } + // Validate that the called func takes a single argument + inner, err := a.semCall(&ast.Call{ + Kind: "Call", + Name: id.Name, + Args: []ast.Expr{&ast.ID{Kind: "ID", Name: "this"}}, + }) + if err != nil { + return nil, err + } + return &dag.MapCall{ + Kind: "MapCall", + Expr: exprs[0], + Inner: inner, + }, nil default: if _, _, err = function.New(a.zctx, name, nargs); err != nil { return nil, fmt.Errorf("%s(): %w", name, err) diff --git a/docs/language/aggregates/map.md b/docs/language/aggregates/collect_map.md similarity index 63% rename from docs/language/aggregates/map.md rename to docs/language/aggregates/collect_map.md index b10d6bf9a1..9bae505f5c 100644 --- a/docs/language/aggregates/map.md +++ b/docs/language/aggregates/collect_map.md @@ -1,16 +1,16 @@ ### Aggregate Function -  **map** — aggregate map values into a single map +  **collect_map** — aggregate map values into a single map ### Synopsis ``` -map(|{any:any}|) -> |{any:any}| +collect_map(|{any:any}|) -> |{any:any}| ``` ### Description -The _map_ aggregate function combines map inputs into a single map output. -If _map_ receives multiple values for the same key, the last value received is +The _collect_map_ aggregate function combines map inputs into a single map output. +If _collect_map_ receives multiple values for the same key, the last value received is retained. If the input keys or values vary in type, the return type will be a map of union of those types. @@ -18,7 +18,7 @@ of union of those types. Combine a sequence of records into a map: ```mdtest-command -echo '{stock:"APPL",price:145.03} {stock:"GOOG",price:87.07}' | zq -z 'map(|{stock:price}|)' - +echo '{stock:"APPL",price:145.03} {stock:"GOOG",price:87.07}' | zq -z 'collect_map(|{stock:price}|)' - ``` => ```mdtest-output @@ -27,7 +27,7 @@ echo '{stock:"APPL",price:145.03} {stock:"GOOG",price:87.07}' | zq -z 'map(|{sto Continuous collection over a simple sequence: ```mdtest-command -echo '|{"APPL":145.03}| |{"GOOG":87.07}| |{"APPL":150.13}|' | zq -z 'yield map(this)' - +echo '|{"APPL":145.03}| |{"GOOG":87.07}| |{"APPL":150.13}|' | zq -z 'yield collect_map(this)' - ``` => ```mdtest-output diff --git a/runtime/expr/agg/agg.go b/runtime/expr/agg/agg.go index c491aa3d53..2cecb7a639 100644 --- a/runtime/expr/agg/agg.go +++ b/runtime/expr/agg/agg.go @@ -53,9 +53,9 @@ func NewPattern(op string, hasarg bool) (Pattern, error) { pattern = func() Function { return newMathReducer(anymath.Add) } - case "map": + case "collect_map": pattern = func() Function { - return newMap() + return newCollectMap() } case "min": pattern = func() Function { diff --git a/runtime/expr/agg/map.go b/runtime/expr/agg/map.go index a4440e8e5c..3098fa6386 100644 --- a/runtime/expr/agg/map.go +++ b/runtime/expr/agg/map.go @@ -6,13 +6,13 @@ import ( "golang.org/x/exp/slices" ) -type Map struct { +type CollectMap struct { entries map[string]mapEntry scratch []byte } -func newMap() *Map { - return &Map{entries: make(map[string]mapEntry)} +func newCollectMap() *CollectMap { + return &CollectMap{entries: make(map[string]mapEntry)} } var _ Function = (*Collect)(nil) @@ -22,7 +22,7 @@ type mapEntry struct { val *zed.Value } -func (m *Map) Consume(val *zed.Value) { +func (m *CollectMap) Consume(val *zed.Value) { if val.IsNull() { return } @@ -43,11 +43,11 @@ func (m *Map) Consume(val *zed.Value) { } } -func (m *Map) ConsumeAsPartial(val *zed.Value) { +func (m *CollectMap) ConsumeAsPartial(val *zed.Value) { m.Consume(val) } -func (m *Map) Result(zctx *zed.Context) *zed.Value { +func (m *CollectMap) Result(zctx *zed.Context) *zed.Value { if len(m.entries) == 0 { return zed.Null } @@ -71,7 +71,7 @@ func (m *Map) Result(zctx *zed.Context) *zed.Value { return zed.NewValue(typ, b) } -func (m *Map) ResultAsPartial(zctx *zed.Context) *zed.Value { +func (m *CollectMap) ResultAsPartial(zctx *zed.Context) *zed.Value { return m.Result(zctx) } diff --git a/runtime/expr/agg/ztests/map-union.yaml b/runtime/expr/agg/ztests/collect-map-union.yaml similarity index 93% rename from runtime/expr/agg/ztests/map-union.yaml rename to runtime/expr/agg/ztests/collect-map-union.yaml index 325f992123..52c2935c81 100644 --- a/runtime/expr/agg/ztests/map-union.yaml +++ b/runtime/expr/agg/ztests/collect-map-union.yaml @@ -1,4 +1,4 @@ -zed: "map(|{k:v}|)" +zed: "collect_map(|{k:v}|)" output-flags: -pretty 2 diff --git a/runtime/expr/agg/ztests/map.yaml b/runtime/expr/agg/ztests/collect-map.yaml similarity index 89% rename from runtime/expr/agg/ztests/map.yaml rename to runtime/expr/agg/ztests/collect-map.yaml index 33f5a0e47a..0b1f0c6714 100644 --- a/runtime/expr/agg/ztests/map.yaml +++ b/runtime/expr/agg/ztests/collect-map.yaml @@ -1,4 +1,4 @@ -zed: "map(|{stock: price}|)" +zed: "collect_map(|{stock: price}|)" output-flags: -pretty 2 diff --git a/runtime/expr/map.go b/runtime/expr/map.go new file mode 100644 index 0000000000..3f40ef946d --- /dev/null +++ b/runtime/expr/map.go @@ -0,0 +1,71 @@ +package expr + +import ( + "github.com/brimdata/zed" + "github.com/brimdata/zed/zcode" +) + +type mapCall struct { + builder zcode.Builder + eval Evaluator + inner Evaluator + zctx *zed.Context + + // vals is used to reduce allocations + vals []zed.Value + // types is used to reduce allocations + types []zed.Type +} + +func NewMapCall(zctx *zed.Context, e, inner Evaluator) Evaluator { + return &mapCall{eval: e, inner: inner, zctx: zctx} +} + +func (a *mapCall) Eval(ectx Context, in *zed.Value) *zed.Value { + v := a.eval.Eval(ectx, in) + if v.IsError() { + return v + } + elems, err := v.Elements() + if err != nil { + return ectx.CopyValue(*a.zctx.WrapError(err.Error(), in)) + } + if len(elems) == 0 { + return v + } + a.vals = a.vals[:0] + a.types = a.types[:0] + for _, elem := range elems { + out := a.inner.Eval(ectx, &elem) + a.vals = append(a.vals, *out) + a.types = append(a.types, out.Type) + } + inner := a.innerType(a.types) + bytes := a.buildVal(inner, a.vals) + if _, ok := zed.TypeUnder(in.Type).(*zed.TypeSet); ok { + return ectx.NewValue(a.zctx.LookupTypeSet(inner), zed.NormalizeSet(bytes)) + } + return ectx.NewValue(a.zctx.LookupTypeArray(inner), bytes) +} + +func (a *mapCall) buildVal(inner zed.Type, vals []zed.Value) []byte { + a.builder.Reset() + if union, ok := inner.(*zed.TypeUnion); ok { + for _, val := range a.vals { + zed.BuildUnion(&a.builder, union.TagOf(val.Type), val.Bytes()) + } + } else { + for _, val := range a.vals { + a.builder.Append(val.Bytes()) + } + } + return a.builder.Bytes() +} + +func (a *mapCall) innerType(types []zed.Type) zed.Type { + types = zed.UniqueTypes(types) + if len(types) == 1 { + return types[0] + } + return a.zctx.LookupTypeUnion(types) +} diff --git a/runtime/expr/ztests/map.yaml b/runtime/expr/ztests/map.yaml new file mode 100644 index 0000000000..de501a45c5 --- /dev/null +++ b/runtime/expr/ztests/map.yaml @@ -0,0 +1,8 @@ +script: | + echo '{a:["foo","bar","baz"]}' | zq -z 'a := map(a,upper)' - + echo '["1","2","3"]' | zq -z 'yield map(this,int64)' - +outputs: + - name: stdout + data: | + {a:["FOO","BAR","BAZ"]} + [1,2,3]