From 17037ec72a94e0a9b6d042c5b7cedafddce4509e Mon Sep 17 00:00:00 2001 From: Julian Hyde Date: Mon, 20 Nov 2023 23:13:20 -0800 Subject: [PATCH] [MOREL-204] Add `take` and `skip` relational operators Now that `take` is a keyword, quote uses of `List.take` function in standard basis library. Push skip and take to Calcite. Fix pushing yield into Calcite (it only worked if it was the last step). Remove call to RelJson constructor deprecated in Calcite 1.36 Fixes #204 --- docs/reference.md | 6 +- .../java/net/hydromatic/morel/ast/Ast.java | 52 +++++++ .../net/hydromatic/morel/ast/AstBuilder.java | 8 + .../java/net/hydromatic/morel/ast/Core.java | 60 ++++++++ .../net/hydromatic/morel/ast/CoreBuilder.java | 8 + .../net/hydromatic/morel/ast/FromBuilder.java | 22 +++ .../java/net/hydromatic/morel/ast/Op.java | 2 + .../net/hydromatic/morel/ast/Shuttle.java | 16 ++ .../net/hydromatic/morel/ast/Visitor.java | 16 ++ .../morel/compile/CalciteCompiler.java | 68 ++++++--- .../hydromatic/morel/compile/Compiler.java | 12 +- .../hydromatic/morel/compile/Resolver.java | 10 ++ .../morel/compile/TypeResolver.java | 16 ++ .../java/net/hydromatic/morel/eval/Codes.java | 141 ++++++++++++++---- src/main/javacc/MorelParser.jj | 12 ++ .../net/hydromatic/morel/AlgebraTest.java | 28 ++++ .../java/net/hydromatic/morel/MainTest.java | 10 ++ src/test/resources/script/builtIn.smli | 16 +- src/test/resources/script/relational.smli | 35 +++++ src/test/resources/script/wordle.smli | 2 +- 20 files changed, 478 insertions(+), 62 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index ab3e31bed..b703f4067 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -38,7 +38,7 @@ Contributions are welcome! In Morel but not Standard ML: * `from` expression with `in`, `suchthat`, `join`, `where`, `group`, - `compute`, `order`, `yield` clauses + `compute`, `order`, `skip`, `take`, `yield` clauses * `union`, `except`, `intersect`, `elem`, `notelem` operators * "*lab* `=`" is optional in `exprow` * identifiers may be quoted @@ -159,7 +159,9 @@ In Standard ML but not in Morel: compute clause (a > 1) | order orderItem1 , ... , orderItemo order clause (o ≥ 1) - | yield exp + | skip exp skip clause + | take exp take clause + | yield exp yield clause groupKey → [ id = ] exp agg → [ id = ] exp [ of exp ] orderItemexp [ desc ] diff --git a/src/main/java/net/hydromatic/morel/ast/Ast.java b/src/main/java/net/hydromatic/morel/ast/Ast.java index 484737420..d74ed1f2d 100644 --- a/src/main/java/net/hydromatic/morel/ast/Ast.java +++ b/src/main/java/net/hydromatic/morel/ast/Ast.java @@ -1707,6 +1707,58 @@ public Where copy(Exp exp) { } } + /** A {@code skip} clause in a {@code from} expression. */ + public static class Skip extends FromStep { + public final Exp exp; + + Skip(Pos pos, Exp exp) { + super(pos, Op.SKIP); + this.exp = exp; + } + + @Override AstWriter unparse(AstWriter w, int left, int right) { + return w.append(" skip ").append(exp, 0, 0); + } + + @Override public AstNode accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + public Skip copy(Exp exp) { + return this.exp.equals(exp) ? this : new Skip(pos, exp); + } + } + + /** A {@code take} clause in a {@code from} expression. */ + public static class Take extends FromStep { + public final Exp exp; + + Take(Pos pos, Exp exp) { + super(pos, Op.TAKE); + this.exp = exp; + } + + @Override AstWriter unparse(AstWriter w, int left, int right) { + return w.append(" take ").append(exp, 0, 0); + } + + @Override public AstNode accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + public Take copy(Exp exp) { + return this.exp.equals(exp) ? this : new Take(pos, exp); + } + } + /** A {@code yield} clause in a {@code from} expression. */ public static class Yield extends FromStep { public final Exp exp; diff --git a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java index ac9493c1f..b46ab652e 100644 --- a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java @@ -483,6 +483,14 @@ public Ast.FromStep where(Pos pos, Ast.Exp exp) { return new Ast.Where(pos, exp); } + public Ast.FromStep skip(Pos pos, Ast.Exp exp) { + return new Ast.Skip(pos, exp); + } + + public Ast.FromStep take(Pos pos, Ast.Exp exp) { + return new Ast.Take(pos, exp); + } + public Ast.FromStep yield(Pos pos, Ast.Exp exp) { return new Ast.Yield(pos, exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/Core.java b/src/main/java/net/hydromatic/morel/ast/Core.java index 74b8799c0..b1138bdf8 100644 --- a/src/main/java/net/hydromatic/morel/ast/Core.java +++ b/src/main/java/net/hydromatic/morel/ast/Core.java @@ -1238,6 +1238,66 @@ public Where copy(Exp exp, List bindings) { } } + /** A {@code skip} clause in a {@code from} expression. */ + public static class Skip extends FromStep { + public final Exp exp; + + Skip(ImmutableList bindings, Exp exp) { + super(Op.SKIP, bindings); + this.exp = requireNonNull(exp, "exp"); + } + + @Override public Skip accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + @Override protected AstWriter unparse(AstWriter w, From from, int ordinal, + int left, int right) { + return w.append(" skip ").append(exp, 0, 0); + } + + public Skip copy(Exp exp, List bindings) { + return exp == this.exp + && bindings.equals(this.bindings) + ? this + : core.skip(bindings, exp); + } + } + + /** A {@code take} clause in a {@code from} expression. */ + public static class Take extends FromStep { + public final Exp exp; + + Take(ImmutableList bindings, Exp exp) { + super(Op.TAKE, bindings); + this.exp = requireNonNull(exp, "exp"); + } + + @Override public Take accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + @Override protected AstWriter unparse(AstWriter w, From from, int ordinal, + int left, int right) { + return w.append(" take ").append(exp, 0, 0); + } + + public Take copy(Exp exp, List bindings) { + return exp == this.exp + && bindings.equals(this.bindings) + ? this + : core.take(bindings, exp); + } + } + /** An {@code order} clause in a {@code from} expression. */ public static class Order extends FromStep { public final ImmutableList orderItems; diff --git a/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java b/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java index fb29a98ff..d99faffb1 100644 --- a/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java @@ -521,6 +521,14 @@ public Core.Where where(List bindings, Core.Exp exp) { return new Core.Where(ImmutableList.copyOf(bindings), exp); } + public Core.Skip skip(List bindings, Core.Exp exp) { + return new Core.Skip(ImmutableList.copyOf(bindings), exp); + } + + public Core.Take take(List bindings, Core.Exp exp) { + return new Core.Take(ImmutableList.copyOf(bindings), exp); + } + public Core.Yield yield_(List bindings, Core.Exp exp) { return new Core.Yield(ImmutableList.copyOf(bindings), exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/FromBuilder.java b/src/main/java/net/hydromatic/morel/ast/FromBuilder.java index 69430d350..2d1149c47 100644 --- a/src/main/java/net/hydromatic/morel/ast/FromBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/FromBuilder.java @@ -29,6 +29,7 @@ import org.apache.calcite.util.Util; import org.checkerframework.checker.nullness.qual.Nullable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @@ -192,6 +193,19 @@ public FromBuilder where(Core.Exp condition) { return addStep(core.where(bindings, condition)); } + public FromBuilder skip(Core.Exp count) { + if (count.op == Op.INT_LITERAL + && ((Core.Literal) count).value.equals(BigDecimal.ZERO)) { + // skip "skip 0" + return this; + } + return addStep(core.skip(bindings, count)); + } + + public FromBuilder take(Core.Exp count) { + return addStep(core.take(bindings, count)); + } + public FromBuilder group(SortedMap groupExps, SortedMap aggregates) { return addStep(core.group(groupExps, aggregates)); @@ -378,6 +392,14 @@ private class StepHandler extends Visitor { where(where.exp); } + @Override protected void visit(Core.Skip skip) { + skip(skip.exp); + } + + @Override protected void visit(Core.Take take) { + take(take.exp); + } + @Override protected void visit(Core.Yield yield) { yield_(false, yield.bindings, yield.exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/Op.java b/src/main/java/net/hydromatic/morel/ast/Op.java index 033e29e59..03bddabb7 100644 --- a/src/main/java/net/hydromatic/morel/ast/Op.java +++ b/src/main/java/net/hydromatic/morel/ast/Op.java @@ -137,6 +137,8 @@ public enum Op { COMPUTE, ORDER, ORDER_ITEM, + SKIP, + TAKE, YIELD, AGGREGATE, IF; diff --git a/src/main/java/net/hydromatic/morel/ast/Shuttle.java b/src/main/java/net/hydromatic/morel/ast/Shuttle.java index 470a2436a..4182d1280 100644 --- a/src/main/java/net/hydromatic/morel/ast/Shuttle.java +++ b/src/main/java/net/hydromatic/morel/ast/Shuttle.java @@ -240,6 +240,14 @@ protected AstNode visit(Ast.Where where) { return ast.where(where.pos, where.exp.accept(this)); } + protected AstNode visit(Ast.Skip skip) { + return ast.skip(skip.pos, skip.exp.accept(this)); + } + + protected AstNode visit(Ast.Take take) { + return ast.take(take.pos, take.exp.accept(this)); + } + protected AstNode visit(Ast.Yield yield) { return ast.yield(yield.pos, yield.exp.accept(this)); } @@ -395,6 +403,14 @@ protected Core.Where visit(Core.Where where) { return where.copy(where.exp.accept(this), where.bindings); } + protected Core.Skip visit(Core.Skip skip) { + return skip.copy(skip.exp.accept(this), skip.bindings); + } + + protected Core.Take visit(Core.Take take) { + return take.copy(take.exp.accept(this), take.bindings); + } + protected Core.Group visit(Core.Group group) { return group.copy(visitSortedMap(group.groupExps), visitSortedMap(group.aggregates)); diff --git a/src/main/java/net/hydromatic/morel/ast/Visitor.java b/src/main/java/net/hydromatic/morel/ast/Visitor.java index 5a2fbd39a..f12e91361 100644 --- a/src/main/java/net/hydromatic/morel/ast/Visitor.java +++ b/src/main/java/net/hydromatic/morel/ast/Visitor.java @@ -206,6 +206,14 @@ protected void visit(Ast.Where where) { where.exp.accept(this); } + protected void visit(Ast.Skip skip) { + skip.exp.accept(this); + } + + protected void visit(Ast.Take take) { + take.exp.accept(this); + } + protected void visit(Ast.Yield yield) { yield.exp.accept(this); } @@ -326,6 +334,14 @@ protected void visit(Core.Where where) { where.exp.accept(this); } + protected void visit(Core.Skip skip) { + skip.exp.accept(this); + } + + protected void visit(Core.Take take) { + take.exp.accept(this); + } + protected void visit(Core.NonRecValDecl valDecl) { valDecl.pat.accept(this); valDecl.exp.accept(this); diff --git a/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java b/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java index 6a30e9b35..838a64864 100644 --- a/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java +++ b/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java @@ -260,7 +260,7 @@ Code toRel4(Environment env, Code code, Type type) { } else { for (Core.Exp arg : args) { cx.relBuilder.values(new String[]{"T"}, true); - yield_(cx, arg); + yield_(cx, ImmutableList.of(), arg); } cx.relBuilder.union(true, args.size()); } @@ -303,9 +303,9 @@ Code toRel4(Environment env, Code code, Type type) { return false; } final JsonBuilder jsonBuilder = new JsonBuilder(); + final RelJson relJson = RelJson.create().withJsonBuilder(jsonBuilder); final String jsonRowType = - jsonBuilder.toJsonString( - new RelJson(jsonBuilder).toJson(rowType)); + jsonBuilder.toJsonString(relJson.toJson(rowType)); final String morelCode = apply.toString(); ThreadLocals.let(CalciteFunctions.THREAD_ENV, new CalciteFunctions.Context(new Session(ImmutableMap.of()), cx.env, @@ -369,7 +369,7 @@ private static void harmonizeRowTypes(RelBuilder relBuilder, int inputCount) { } } - @Override protected RelCode compileFrom(Context cx, Core.From from) { + @Override protected Code compileFrom(Context cx, Core.From from) { final Code code = super.compileFrom(cx, from); return new RelCode() { @Override public Describer describe(Describer describer) { @@ -400,7 +400,7 @@ private static void harmonizeRowTypes(RelBuilder relBuilder, int inputCount) { || getLast(from.steps).op != Op.YIELD) { final Core.Exp implicitYieldExp = core.implicitYieldExp(typeSystem, from.steps); - cx = yield_(cx, implicitYieldExp); + cx = yield_(cx, ImmutableList.of(), implicitYieldExp); } return true; } @@ -411,6 +411,10 @@ private RelContext step(RelContext cx, int i, Core.FromStep fromStep) { return join(cx, i, (Core.Scan) fromStep); case WHERE: return where(cx, (Core.Where) fromStep); + case SKIP: + return skip(cx, (Core.Skip) fromStep); + case TAKE: + return take(cx, (Core.Take) fromStep); case ORDER: return order(cx, (Core.Order) fromStep); case GROUP: @@ -425,25 +429,28 @@ private RelContext step(RelContext cx, int i, Core.FromStep fromStep) { } private RelContext yield_(RelContext cx, Core.Yield yield) { - return yield_(cx, yield.exp); + return yield_(cx, yield.bindings, yield.exp); } - private RelContext yield_(RelContext cx, Core.Exp exp) { + private RelContext yield_(RelContext cx, List bindings, + Core.Exp exp) { final Core.Tuple tuple; switch (exp.op) { case ID: final Core.Id id = (Core.Id) exp; tuple = toRecord(cx, id); if (tuple != null) { - return yield_(cx, tuple); + return yield_(cx, bindings, tuple); } break; case TUPLE: tuple = (Core.Tuple) exp; + final List names = + ImmutableList.copyOf(tuple.type().argNameTypes().keySet()); cx.relBuilder.project(transform(tuple.args, e -> translate(cx, e)), - ImmutableList.copyOf(tuple.type().argNameTypes().keySet())); - return cx; + names); + return getRelContext(cx, cx.env.bindAll(bindings), names); } RexNode rex = translate(cx, exp); cx.relBuilder.project(rex); @@ -489,7 +496,7 @@ record = toRecord(cx, id); if (cx.map.containsKey(id.idPat.name)) { // Not a record, so must be a scalar. It is represented in Calcite // as a record with one field. - final VarData fn = cx.map.get(id.idPat.name); + final VarData fn = requireNonNull(cx.map.get(id.idPat.name)); return fn.apply(cx.relBuilder); } break; @@ -611,9 +618,9 @@ private RexNode morelScalar(RelContext cx, Core.Exp exp) { final RelDataType calciteType = Converters.toCalciteType(exp.type, typeFactory); final JsonBuilder jsonBuilder = new JsonBuilder(); + final RelJson relJson = RelJson.create().withJsonBuilder(jsonBuilder); final String jsonType = - jsonBuilder.toJsonString( - new RelJson(jsonBuilder).toJson(calciteType)); + jsonBuilder.toJsonString(relJson.toJson(calciteType)); final String morelCode = exp.toString(); return cx.relBuilder.getRexBuilder().makeCall(calciteType, CalciteFunctions.SCALAR_OPERATOR, @@ -695,6 +702,26 @@ private RelContext where(RelContext cx, Core.Where where) { return cx; } + private RelContext skip(RelContext cx, Core.Skip skip) { + if (skip.exp.op != Op.INT_LITERAL) { + throw new AssertionError("skip requires literal: " + skip.exp); + } + int offset = ((Core.Literal) skip.exp).unwrap(Integer.class); + int fetch = -1; // per Calcite: "negative means no limit" + cx.relBuilder.limit(offset, fetch); + return cx; + } + + private RelContext take(RelContext cx, Core.Take take) { + if (take.exp.op != Op.INT_LITERAL) { + throw new AssertionError("take requires literal: " + take.exp); + } + int offset = 0; + int fetch = ((Core.Literal) take.exp).unwrap(Integer.class); + cx.relBuilder.limit(offset, fetch); + return cx; + } + private RelContext order(RelContext cx, Core.Order order) { final List exps = new ArrayList<>(); order.orderItems.forEach(i -> { @@ -709,7 +736,6 @@ private RelContext order(RelContext cx, Core.Order order) { } private RelContext group(RelContext cx, Core.Group group) { - final SortedMap map = new TreeMap<>(); final List bindings = new ArrayList<>(); final List nodes = new ArrayList<>(); final List names = new ArrayList<>(); @@ -734,20 +760,24 @@ private RelContext group(RelContext cx, Core.Group group) { // Create an Aggregate operator. cx.relBuilder.aggregate(groupKey, aggregateCalls); + return getRelContext(cx, cx.env.bindAll(bindings), names); + } + private static RelContext getRelContext(RelContext cx, Environment env, + List names) { // Permute the fields so that they are sorted by name, per Morel records. final List sortedNames = Ordering.natural().immutableSortedCopy(names); cx.relBuilder.rename(names) .project(cx.relBuilder.fields(sortedNames)); final RelDataType rowType = cx.relBuilder.peek().getRowType(); - sortedNames.forEach(name -> { - final int i = map.size(); - map.put(name, new VarData(PrimitiveType.UNIT, i, rowType)); - }); + final SortedMap map = new TreeMap<>(); + sortedNames.forEach(name -> + map.put(name, + new VarData(PrimitiveType.UNIT, map.size(), rowType))); // Return a context containing a variable for each output field. - return new RelContext(cx.env.bindAll(bindings), cx, cx.relBuilder, + return new RelContext(env, cx, cx.relBuilder, ImmutableSortedMap.copyOfSorted(map), 1); } diff --git a/src/main/java/net/hydromatic/morel/compile/Compiler.java b/src/main/java/net/hydromatic/morel/compile/Compiler.java index 4dc633f5b..d7f4d91ce 100644 --- a/src/main/java/net/hydromatic/morel/compile/Compiler.java +++ b/src/main/java/net/hydromatic/morel/compile/Compiler.java @@ -151,7 +151,7 @@ void apply(Consumer outLines, Consumer outBindings, } /** Compilation context. */ - static class Context { + public static class Context { final Environment env; Context(Environment env) { @@ -388,6 +388,16 @@ && getOnlyElement(bindings).id.type.equals(elementType)) { final Code filterCode = compile(cx, where.exp); return () -> Codes.whereRowSink(filterCode, nextFactory.get()); + case SKIP: + final Core.Skip skip = (Core.Skip) firstStep; + final Code skipCode = compile(cx, skip.exp); + return () -> Codes.skipRowSink(skipCode, nextFactory.get()); + + case TAKE: + final Core.Take take = (Core.Take) firstStep; + final Code takeCode = compile(cx, take.exp); + return () -> Codes.takeRowSink(takeCode, nextFactory.get()); + case YIELD: final Core.Yield yield = (Core.Yield) firstStep; if (steps.size() == 1) { diff --git a/src/main/java/net/hydromatic/morel/compile/Resolver.java b/src/main/java/net/hydromatic/morel/compile/Resolver.java index 2318d7762..5dff91189 100644 --- a/src/main/java/net/hydromatic/morel/compile/Resolver.java +++ b/src/main/java/net/hydromatic/morel/compile/Resolver.java @@ -919,6 +919,16 @@ Core.Exp run(Ast.From from) { fromBuilder.where(r.toCore(where.exp)); } + @Override protected void visit(Ast.Skip skip) { + final Resolver r = withEnv(env); // do not use 'from' bindings + fromBuilder.skip(r.toCore(skip.exp)); + } + + @Override protected void visit(Ast.Take take) { + final Resolver r = withEnv(env); // do not use 'from' bindings + fromBuilder.take(r.toCore(take.exp)); + } + @Override protected void visit(Ast.Yield yield) { final Resolver r = withEnv(fromBuilder.bindings()); fromBuilder.yield_(r.toCore(yield.exp)); diff --git a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java index 3180d370e..cd70681a7 100644 --- a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java +++ b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java @@ -546,6 +546,22 @@ private Pair deduceStepType(TypeEnv env, fromSteps.add(where.copy(filter2)); return Pair.of(env2, v); + case SKIP: + final Ast.Skip skip = (Ast.Skip) step; + final Unifier.Variable v11 = unifier.variable(); + final Ast.Exp skipCount = deduceType(env2, skip.exp, v11); + equiv(v11, toTerm(PrimitiveType.INT)); + fromSteps.add(skip.copy(skipCount)); + return Pair.of(env2, v); + + case TAKE: + final Ast.Take take = (Ast.Take) step; + final Unifier.Variable v12 = unifier.variable(); + final Ast.Exp takeCount = deduceType(env2, take.exp, v12); + equiv(v12, toTerm(PrimitiveType.INT)); + fromSteps.add(take.copy(takeCount)); + return Pair.of(env2, v); + case YIELD: final Ast.Yield yield = (Ast.Yield) step; final Unifier.Variable v6 = unifier.variable(); diff --git a/src/main/java/net/hydromatic/morel/eval/Codes.java b/src/main/java/net/hydromatic/morel/eval/Codes.java index 7533d3d86..1a65ebef3 100644 --- a/src/main/java/net/hydromatic/morel/eval/Codes.java +++ b/src/main/java/net/hydromatic/morel/eval/Codes.java @@ -507,6 +507,7 @@ public static Code from(Supplier rowSinkFactory) { @Override public Object eval(EvalEnv env) { final RowSink rowSink = rowSinkFactory.get(); + rowSink.start(env); rowSink.accept(env); return rowSink.result(env); } @@ -524,6 +525,16 @@ public static RowSink whereRowSink(Code filterCode, RowSink rowSink) { return new WhereRowSink(filterCode, rowSink); } + /** Creates a {@link RowSink} for a {@code skip} clause. */ + public static RowSink skipRowSink(Code filterCode, RowSink rowSink) { + return new SkipRowSink(filterCode, rowSink); + } + + /** Creates a {@link RowSink} for a {@code take} clause. */ + public static RowSink takeRowSink(Code filterCode, RowSink rowSink) { + return new TakeRowSink(filterCode, rowSink); + } + /** Creates a {@link RowSink} for a {@code order} clause. */ public static RowSink orderRowSink( Iterable> codes, @@ -3042,26 +3053,47 @@ public Object eval(EvalEnv env) { /** Accepts rows produced by a supplier as part of a {@code from} clause. */ public interface RowSink extends Describable { + void start(EvalEnv env); void accept(EvalEnv env); List result(EvalEnv env); } + /** Abstract implementation for row sinks that have one successor. */ + abstract static class BaseRowSink implements RowSink { + final RowSink rowSink; + + BaseRowSink(RowSink rowSink) { + this.rowSink = requireNonNull(rowSink); + } + + @Override public void start(EvalEnv env) { + rowSink.start(env); + } + + @Override public void accept(EvalEnv env) { + rowSink.accept(env); + } + + @Override public List result(EvalEnv env) { + return rowSink.result(env); + } + } + /** Implementation of {@link RowSink} for a {@code join} clause. */ - static class ScanRowSink implements RowSink { + static class ScanRowSink extends BaseRowSink { final Op op; // inner, left, right, full private final Core.Pat pat; private final Code code; final Code conditionCode; - final RowSink rowSink; ScanRowSink(Op op, Core.Pat pat, Code code, Code conditionCode, RowSink rowSink) { + super(rowSink); checkArgument(op == Op.INNER_JOIN); this.op = op; this.pat = pat; this.code = code; this.conditionCode = conditionCode; - this.rowSink = rowSink; } @Override public Describer describe(Describer describer) { @@ -3078,7 +3110,7 @@ private static boolean isConstantTrue(Code code) { && Objects.equals(code.eval(null), true); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { final MutableEvalEnv mutableEvalEnv = env.bindMutablePat(pat); final Iterable elements = (Iterable) code.eval(env); for (Object element : elements) { @@ -3090,20 +3122,15 @@ public void accept(EvalEnv env) { } } } - - public List result(EvalEnv env) { - return rowSink.result(env); - } } /** Implementation of {@link RowSink} for a {@code where} clause. */ - static class WhereRowSink implements RowSink { + static class WhereRowSink extends BaseRowSink { final Code filterCode; - final RowSink rowSink; WhereRowSink(Code filterCode, RowSink rowSink) { + super(rowSink); this.filterCode = filterCode; - this.rowSink = rowSink; } @Override public Describer describe(Describer describer) { @@ -3112,38 +3139,92 @@ static class WhereRowSink implements RowSink { .arg("sink", rowSink)); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if ((Boolean) filterCode.eval(env)) { rowSink.accept(env); } } + } - public List result(EvalEnv env) { - return rowSink.result(env); + /** Implementation of {@link RowSink} for a {@code skip} clause. */ + static class SkipRowSink extends BaseRowSink { + final Code skipCode; + int skip; + + SkipRowSink(Code skipCode, RowSink rowSink) { + super(rowSink); + this.skipCode = skipCode; + } + + @Override public Describer describe(Describer describer) { + return describer.start("skip", d -> + d.arg("count", skipCode) + .arg("sink", rowSink)); + } + + @Override public void start(EvalEnv env) { + skip = (Integer) skipCode.eval(env); + super.start(env); + } + + @Override public void accept(EvalEnv env) { + if (skip > 0) { + --skip; + } else { + rowSink.accept(env); + } + } + } + + /** Implementation of {@link RowSink} for a {@code take} clause. */ + static class TakeRowSink extends BaseRowSink { + final Code takeCode; + int take; + + TakeRowSink(Code takeCode, RowSink rowSink) { + super(rowSink); + this.takeCode = takeCode; + } + + @Override public Describer describe(Describer describer) { + return describer.start("take", d -> + d.arg("count", takeCode) + .arg("sink", rowSink)); + } + + @Override public void start(EvalEnv env) { + take = (Integer) takeCode.eval(env); + super.start(env); + } + + @Override public void accept(EvalEnv env) { + if (take > 0) { + --take; + rowSink.accept(env); + } } } /** Implementation of {@link RowSink} for a {@code group} clause. */ - private static class GroupRowSink implements RowSink { + private static class GroupRowSink extends BaseRowSink { final Code keyCode; final ImmutableList inNames; final ImmutableList keyNames; /** group names followed by aggregate names */ final ImmutableList outNames; final ImmutableList aggregateCodes; - final RowSink rowSink; final ListMultimap map = ArrayListMultimap.create(); final Object[] values; GroupRowSink(Code keyCode, ImmutableList aggregateCodes, ImmutableList inNames, ImmutableList keyNames, ImmutableList outNames, RowSink rowSink) { + super(rowSink); this.keyCode = requireNonNull(keyCode); this.aggregateCodes = requireNonNull(aggregateCodes); this.inNames = requireNonNull(inNames); this.keyNames = requireNonNull(keyNames); this.outNames = requireNonNull(outNames); - this.rowSink = requireNonNull(rowSink); this.values = inNames.size() == 1 ? null : new Object[inNames.size()]; checkArgument(isPrefix(keyNames, outNames)); } @@ -3161,7 +3242,7 @@ private static boolean isPrefix(List list0, List list1) { }); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if (inNames.size() == 1) { map.put(keyCode.eval(env), env.getOpt(inNames.get(0))); } else { @@ -3172,7 +3253,7 @@ public void accept(EvalEnv env) { } } - public List result(final EvalEnv env) { + @Override public List result(final EvalEnv env) { // Derive env2, the environment for our consumer. It consists of our input // environment plus output names. EvalEnv env2 = env; @@ -3213,18 +3294,17 @@ public List result(final EvalEnv env) { } /** Implementation of {@link RowSink} for an {@code order} clause. */ - static class OrderRowSink implements RowSink { + static class OrderRowSink extends BaseRowSink { final ImmutablePairList codes; final ImmutableList names; - final RowSink rowSink; final List rows = new ArrayList<>(); final Object[] values; OrderRowSink(ImmutablePairList codes, ImmutableList names, RowSink rowSink) { + super(rowSink); this.codes = codes; this.names = names; - this.rowSink = rowSink; this.values = names.size() == 1 ? null : new Object[names.size()]; } @@ -3235,7 +3315,7 @@ static class OrderRowSink implements RowSink { }); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if (values == null) { rows.add(env.getOpt(names.get(0))); } else { @@ -3246,7 +3326,7 @@ public void accept(EvalEnv env) { } } - public List result(final EvalEnv env) { + @Override public List result(final EvalEnv env) { final MutableEvalEnv leftEnv = env.bindMutableArray(names); final MutableEvalEnv rightEnv = env.bindMutableArray(names); rows.sort((left, right) -> { @@ -3277,17 +3357,16 @@ public List result(final EvalEnv env) { * step is allowed to generate expressions that are not records. Non-record * expressions (e.g. {@code int} expressions) do not have a name, and * therefore the value cannot be passed via the {@link EvalEnv}. */ - private static class YieldRowSink implements RowSink { + private static class YieldRowSink extends BaseRowSink { private final ImmutableList names; private final ImmutableList codes; - private final RowSink rowSink; private final Object[] values; YieldRowSink(ImmutableList names, ImmutableList codes, RowSink rowSink) { + super(rowSink); this.names = names; this.codes = codes; - this.rowSink = rowSink; this.values = names.size() == 1 ? null : new Object[names.size()]; } @@ -3310,10 +3389,6 @@ private static class YieldRowSink implements RowSink { } rowSink.accept(env2); } - - @Override public List result(EvalEnv env) { - return rowSink.result(env); - } } /** Implementation of {@link RowSink} that the last step of a {@code from} @@ -3330,6 +3405,10 @@ private static class CollectRowSink implements RowSink { return describer.start("collect", d -> d.arg("", code)); } + @Override public void start(EvalEnv env) { + list.clear(); + } + @Override public void accept(EvalEnv env) { list.add(code.eval(env)); } diff --git a/src/main/javacc/MorelParser.jj b/src/main/javacc/MorelParser.jj index b2d6ad9c6..7ddcb8cf3 100644 --- a/src/main/javacc/MorelParser.jj +++ b/src/main/javacc/MorelParser.jj @@ -357,6 +357,8 @@ void fromStep(List steps) : final Pair patExp; final Exp condition; final Exp filterExp; + final Exp skipExp; + final Exp takeExp; final Exp yieldExp; final PairList groupExps; final List aggregates; @@ -406,6 +408,14 @@ void fromStep(List steps) : { span = Span.of(pos()); } orderItems = orderItemCommaList() { steps.add(ast.order(span.end(this), orderItems)); } +| + { span = Span.of(pos()); } skipExp = expression() { + steps.add(ast.skip(span.end(this), skipExp)); + } +| + { span = Span.of(pos()); } takeExp = expression() { + steps.add(ast.take(span.end(this), takeExp)); + } | { span = Span.of(pos()); } yieldExp = expression() { steps.add(ast.yield(span.end(this), yieldExp)); @@ -1418,6 +1428,8 @@ AstNode statementEof() : | < JOIN: "join" > | < ON: "on" > | < ORDER: "order" > +| < SKIP_: "skip" > +| < TAKE: "take" > | < WHERE: "where" > | < YIELD: "yield" > } diff --git a/src/test/java/net/hydromatic/morel/AlgebraTest.java b/src/test/java/net/hydromatic/morel/AlgebraTest.java index 1a8e1c3d0..d91996f04 100644 --- a/src/test/java/net/hydromatic/morel/AlgebraTest.java +++ b/src/test/java/net/hydromatic/morel/AlgebraTest.java @@ -68,6 +68,27 @@ public class AlgebraTest { 10)); } + @Test void testScottOrder() { + final String ml = "from e in scott.emp\n" + + " yield {e.empno, e.deptno}\n" + + " order empno desc\n" + + " skip 2 take 4"; + // When fixed, + // [CALCITE-6128] RelBuilder.sortLimit should compose offset and fetch + // will yield a plan with one fewer LogicalSort + final String plan = "LogicalSort(fetch=[4])\n" + + " LogicalSort(sort0=[$1], dir0=[DESC], offset=[2])\n" + + " LogicalProject(deptno=[$7], empno=[$0])\n" + + " JdbcTableScan(table=[[scott, EMP]])\n"; + ml(ml) + .withBinding("scott", BuiltInDataSet.SCOTT) + .assertType("{deptno:int, empno:int} list") + .assertCalcite(is(plan)) + .assertEvalIter( + equalsOrdered(list(30, 7900), list(20, 7876), list(30, 7844), + list(10, 7839))); + } + @Test void testScottJoin() { final String ml = "let\n" + " val emps = #emp scott\n" @@ -153,6 +174,13 @@ public class AlgebraTest { + "group r.b compute sb = sum of r.b,\n" + " mb = min of r.b, a = count\n" + "yield {a, a2 = a + b, sb}", + "from e in scott.emp\n" + + "yield {e.ename, x = e.deptno * 2}", + "from e in scott.emp\n" + + "order e.ename", + "from e in scott.emp\n" + + "order e.ename desc\n" + + "take 3", "from e in scott.emp,\n" + " d in scott.dept\n" + "where e.deptno = d.deptno\n" diff --git a/src/test/java/net/hydromatic/morel/MainTest.java b/src/test/java/net/hydromatic/morel/MainTest.java index 9e64e7ae9..9035501a1 100644 --- a/src/test/java/net/hydromatic/morel/MainTest.java +++ b/src/test/java/net/hydromatic/morel/MainTest.java @@ -1752,6 +1752,12 @@ private static List node(Object... args) { .assertParse("from e in emps" + " group id = #id e compute count = count" + " join d in depts where false"); + ml("from e in emps skip 1 take 2").assertParseSame(); + ml("from e in emps order e.empno take 2") + .assertParse("from e in emps order #empno e take 2"); + ml("from e in emps order e.empno take 2 skip 3 skip 1+1 take 2") + .assertParse("from e in emps order #empno e take 2 skip 3 skip 1 + 1 " + + "take 2"); ml("fn f => from i in [1, 2, 3] where f i") .assertParseSame() .assertType("(int -> bool) -> int list"); @@ -2399,10 +2405,14 @@ private static List node(Object... args) { @Test void testFromOrderYield() { final String ml = "from r in [{a=1,b=2},{a=1,b=0},{a=2,b=1}]\n" + " order r.a desc, r.b\n" + + " skip 0\n" + + " take 4 + 6\n" + " yield {r.a, b10 = r.b * 10}"; final String expected = "from r in" + " [{a = 1, b = 2}, {a = 1, b = 0}, {a = 2, b = 1}]" + " order #a r desc, #b r" + + " skip 0" + + " take 4 + 6" + " yield {a = #a r, b10 = #b r * 10}"; ml(ml).assertParse(expected) .assertType(hasMoniker("{a:int, b10:int} list")) diff --git a/src/test/resources/script/builtIn.smli b/src/test/resources/script/builtIn.smli index 4ef77491c..fc2989452 100644 --- a/src/test/resources/script/builtIn.smli +++ b/src/test/resources/script/builtIn.smli @@ -697,20 +697,20 @@ Sys.plan (); > : string (*) val take : 'a list * int -> 'a list -List.take; +List.`take`; > val it = fn : 'a list * int -> 'a list -List.take ([1,2,3], 0); +List.`take` ([1,2,3], 0); > val it = [] : int list -List.take ([1,2,3], 1); +List.`take` ([1,2,3], 1); > val it = [1] : int list -List.take ([1,2,3], 3); +List.`take` ([1,2,3], 3); > val it = [1,2,3] : int list -List.take ([1,2,3], 4); +List.`take` ([1,2,3], 4); > uncaught exception Subscript [subscript out of bounds] -> raised at: stdIn:1.1-1.23 -List.take ([1,2,3], ~1); +> raised at: stdIn:1.1-1.25 +List.`take` ([1,2,3], ~1); > uncaught exception Subscript [subscript out of bounds] -> raised at: stdIn:1.1-1.24 +> raised at: stdIn:1.1-1.26 Sys.plan (); > val it = > "apply2(fnValue List.take, tuple(constant(1), constant(2), constant(3)), constant(-1))" diff --git a/src/test/resources/script/relational.smli b/src/test/resources/script/relational.smli index 29c0ac293..69cf69146 100644 --- a/src/test/resources/script/relational.smli +++ b/src/test/resources/script/relational.smli @@ -175,6 +175,41 @@ from e in emps order deptno desc; > val it = [30,30,20,10] : int list +from e in emps + yield {e.deptno} + order deptno desc + skip 1; +> val it = [30,20,10] : int list + +from e in emps + yield {e.deptno} + order deptno desc + skip 1 + take 2; +> val it = [30,20] : int list + +from e in emps + yield {e.deptno} + order deptno desc + take 2; +> val it = [30,30] : int list + +(*) Pass 'take' and 'skip' via function arguments +let + fun earlyEmps n = + from e in emps + yield {e.id, e.deptno} + order id + skip n - 2 + take n +in + (earlyEmps 2, earlyEmps 3) +end; +> val it = +> ([{deptno=10,id=100},{deptno=20,id=101}], +> [{deptno=20,id=101},{deptno=30,id=102},{deptno=30,id=103}]) +> : {deptno:int, id:int} list * {deptno:int, id:int} list + (*) 'yield' followed by 'order' from e in emps yield {e.deptno, x = e.deptno, e.name} diff --git a/src/test/resources/script/wordle.smli b/src/test/resources/script/wordle.smli index 703f8269b..ddd602996 100644 --- a/src/test/resources/script/wordle.smli +++ b/src/test/resources/script/wordle.smli @@ -131,7 +131,7 @@ fun bestGuesses words = > val bestGuesses = fn : string list -> {f:int, w:string} list (*) Run on a sample of 500 words; faster than the full set of 12,972 words. -val sampleWords = if slow then words else List.take (words, 200); +val sampleWords = if slow then words else List.`take` (words, 200); > val sampleWords = > ["aahed","aalii","aargh","aarti","abaca","abaci","aback","abacs","abaft", > "abaka","abamp","aband","abase","abash","abask","abate","abaya","abbas",