diff --git a/docs/reference.md b/docs/reference.md index ab3e31be..b703f406 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -38,7 +38,7 @@ Contributions are welcome! In Morel but not Standard ML: * `from` expression with `in`, `suchthat`, `join`, `where`, `group`, - `compute`, `order`, `yield` clauses + `compute`, `order`, `skip`, `take`, `yield` clauses * `union`, `except`, `intersect`, `elem`, `notelem` operators * "*lab* `=`" is optional in `exprow` * identifiers may be quoted @@ -159,7 +159,9 @@ In Standard ML but not in Morel: compute clause (a > 1) | order orderItem1 , ... , orderItemo order clause (o ≥ 1) - | yield exp + | skip exp skip clause + | take exp take clause + | yield exp yield clause groupKey → [ id = ] exp agg → [ id = ] exp [ of exp ] orderItemexp [ desc ] diff --git a/src/main/java/net/hydromatic/morel/ast/Ast.java b/src/main/java/net/hydromatic/morel/ast/Ast.java index 48473742..d74ed1f2 100644 --- a/src/main/java/net/hydromatic/morel/ast/Ast.java +++ b/src/main/java/net/hydromatic/morel/ast/Ast.java @@ -1707,6 +1707,58 @@ public Where copy(Exp exp) { } } + /** A {@code skip} clause in a {@code from} expression. */ + public static class Skip extends FromStep { + public final Exp exp; + + Skip(Pos pos, Exp exp) { + super(pos, Op.SKIP); + this.exp = exp; + } + + @Override AstWriter unparse(AstWriter w, int left, int right) { + return w.append(" skip ").append(exp, 0, 0); + } + + @Override public AstNode accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + public Skip copy(Exp exp) { + return this.exp.equals(exp) ? this : new Skip(pos, exp); + } + } + + /** A {@code take} clause in a {@code from} expression. */ + public static class Take extends FromStep { + public final Exp exp; + + Take(Pos pos, Exp exp) { + super(pos, Op.TAKE); + this.exp = exp; + } + + @Override AstWriter unparse(AstWriter w, int left, int right) { + return w.append(" take ").append(exp, 0, 0); + } + + @Override public AstNode accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + public Take copy(Exp exp) { + return this.exp.equals(exp) ? this : new Take(pos, exp); + } + } + /** A {@code yield} clause in a {@code from} expression. */ public static class Yield extends FromStep { public final Exp exp; diff --git a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java index ac9493c1..b46ab652 100644 --- a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java @@ -483,6 +483,14 @@ public Ast.FromStep where(Pos pos, Ast.Exp exp) { return new Ast.Where(pos, exp); } + public Ast.FromStep skip(Pos pos, Ast.Exp exp) { + return new Ast.Skip(pos, exp); + } + + public Ast.FromStep take(Pos pos, Ast.Exp exp) { + return new Ast.Take(pos, exp); + } + public Ast.FromStep yield(Pos pos, Ast.Exp exp) { return new Ast.Yield(pos, exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/Core.java b/src/main/java/net/hydromatic/morel/ast/Core.java index 74b8799c..b1138bdf 100644 --- a/src/main/java/net/hydromatic/morel/ast/Core.java +++ b/src/main/java/net/hydromatic/morel/ast/Core.java @@ -1238,6 +1238,66 @@ public Where copy(Exp exp, List bindings) { } } + /** A {@code skip} clause in a {@code from} expression. */ + public static class Skip extends FromStep { + public final Exp exp; + + Skip(ImmutableList bindings, Exp exp) { + super(Op.SKIP, bindings); + this.exp = requireNonNull(exp, "exp"); + } + + @Override public Skip accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + @Override protected AstWriter unparse(AstWriter w, From from, int ordinal, + int left, int right) { + return w.append(" skip ").append(exp, 0, 0); + } + + public Skip copy(Exp exp, List bindings) { + return exp == this.exp + && bindings.equals(this.bindings) + ? this + : core.skip(bindings, exp); + } + } + + /** A {@code take} clause in a {@code from} expression. */ + public static class Take extends FromStep { + public final Exp exp; + + Take(ImmutableList bindings, Exp exp) { + super(Op.TAKE, bindings); + this.exp = requireNonNull(exp, "exp"); + } + + @Override public Take accept(Shuttle shuttle) { + return shuttle.visit(this); + } + + @Override public void accept(Visitor visitor) { + visitor.visit(this); + } + + @Override protected AstWriter unparse(AstWriter w, From from, int ordinal, + int left, int right) { + return w.append(" take ").append(exp, 0, 0); + } + + public Take copy(Exp exp, List bindings) { + return exp == this.exp + && bindings.equals(this.bindings) + ? this + : core.take(bindings, exp); + } + } + /** An {@code order} clause in a {@code from} expression. */ public static class Order extends FromStep { public final ImmutableList orderItems; diff --git a/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java b/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java index fb29a98f..d99faffb 100644 --- a/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/CoreBuilder.java @@ -521,6 +521,14 @@ public Core.Where where(List bindings, Core.Exp exp) { return new Core.Where(ImmutableList.copyOf(bindings), exp); } + public Core.Skip skip(List bindings, Core.Exp exp) { + return new Core.Skip(ImmutableList.copyOf(bindings), exp); + } + + public Core.Take take(List bindings, Core.Exp exp) { + return new Core.Take(ImmutableList.copyOf(bindings), exp); + } + public Core.Yield yield_(List bindings, Core.Exp exp) { return new Core.Yield(ImmutableList.copyOf(bindings), exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/FromBuilder.java b/src/main/java/net/hydromatic/morel/ast/FromBuilder.java index 69430d35..2d1149c4 100644 --- a/src/main/java/net/hydromatic/morel/ast/FromBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/FromBuilder.java @@ -29,6 +29,7 @@ import org.apache.calcite.util.Util; import org.checkerframework.checker.nullness.qual.Nullable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @@ -192,6 +193,19 @@ public FromBuilder where(Core.Exp condition) { return addStep(core.where(bindings, condition)); } + public FromBuilder skip(Core.Exp count) { + if (count.op == Op.INT_LITERAL + && ((Core.Literal) count).value.equals(BigDecimal.ZERO)) { + // skip "skip 0" + return this; + } + return addStep(core.skip(bindings, count)); + } + + public FromBuilder take(Core.Exp count) { + return addStep(core.take(bindings, count)); + } + public FromBuilder group(SortedMap groupExps, SortedMap aggregates) { return addStep(core.group(groupExps, aggregates)); @@ -378,6 +392,14 @@ private class StepHandler extends Visitor { where(where.exp); } + @Override protected void visit(Core.Skip skip) { + skip(skip.exp); + } + + @Override protected void visit(Core.Take take) { + take(take.exp); + } + @Override protected void visit(Core.Yield yield) { yield_(false, yield.bindings, yield.exp); } diff --git a/src/main/java/net/hydromatic/morel/ast/Op.java b/src/main/java/net/hydromatic/morel/ast/Op.java index 033e29e5..03bddabb 100644 --- a/src/main/java/net/hydromatic/morel/ast/Op.java +++ b/src/main/java/net/hydromatic/morel/ast/Op.java @@ -137,6 +137,8 @@ public enum Op { COMPUTE, ORDER, ORDER_ITEM, + SKIP, + TAKE, YIELD, AGGREGATE, IF; diff --git a/src/main/java/net/hydromatic/morel/ast/Shuttle.java b/src/main/java/net/hydromatic/morel/ast/Shuttle.java index 470a2436..4182d128 100644 --- a/src/main/java/net/hydromatic/morel/ast/Shuttle.java +++ b/src/main/java/net/hydromatic/morel/ast/Shuttle.java @@ -240,6 +240,14 @@ protected AstNode visit(Ast.Where where) { return ast.where(where.pos, where.exp.accept(this)); } + protected AstNode visit(Ast.Skip skip) { + return ast.skip(skip.pos, skip.exp.accept(this)); + } + + protected AstNode visit(Ast.Take take) { + return ast.take(take.pos, take.exp.accept(this)); + } + protected AstNode visit(Ast.Yield yield) { return ast.yield(yield.pos, yield.exp.accept(this)); } @@ -395,6 +403,14 @@ protected Core.Where visit(Core.Where where) { return where.copy(where.exp.accept(this), where.bindings); } + protected Core.Skip visit(Core.Skip skip) { + return skip.copy(skip.exp.accept(this), skip.bindings); + } + + protected Core.Take visit(Core.Take take) { + return take.copy(take.exp.accept(this), take.bindings); + } + protected Core.Group visit(Core.Group group) { return group.copy(visitSortedMap(group.groupExps), visitSortedMap(group.aggregates)); diff --git a/src/main/java/net/hydromatic/morel/ast/Visitor.java b/src/main/java/net/hydromatic/morel/ast/Visitor.java index 5a2fbd39..f12e9136 100644 --- a/src/main/java/net/hydromatic/morel/ast/Visitor.java +++ b/src/main/java/net/hydromatic/morel/ast/Visitor.java @@ -206,6 +206,14 @@ protected void visit(Ast.Where where) { where.exp.accept(this); } + protected void visit(Ast.Skip skip) { + skip.exp.accept(this); + } + + protected void visit(Ast.Take take) { + take.exp.accept(this); + } + protected void visit(Ast.Yield yield) { yield.exp.accept(this); } @@ -326,6 +334,14 @@ protected void visit(Core.Where where) { where.exp.accept(this); } + protected void visit(Core.Skip skip) { + skip.exp.accept(this); + } + + protected void visit(Core.Take take) { + take.exp.accept(this); + } + protected void visit(Core.NonRecValDecl valDecl) { valDecl.pat.accept(this); valDecl.exp.accept(this); diff --git a/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java b/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java index 6a30e9b3..838a6486 100644 --- a/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java +++ b/src/main/java/net/hydromatic/morel/compile/CalciteCompiler.java @@ -260,7 +260,7 @@ Code toRel4(Environment env, Code code, Type type) { } else { for (Core.Exp arg : args) { cx.relBuilder.values(new String[]{"T"}, true); - yield_(cx, arg); + yield_(cx, ImmutableList.of(), arg); } cx.relBuilder.union(true, args.size()); } @@ -303,9 +303,9 @@ Code toRel4(Environment env, Code code, Type type) { return false; } final JsonBuilder jsonBuilder = new JsonBuilder(); + final RelJson relJson = RelJson.create().withJsonBuilder(jsonBuilder); final String jsonRowType = - jsonBuilder.toJsonString( - new RelJson(jsonBuilder).toJson(rowType)); + jsonBuilder.toJsonString(relJson.toJson(rowType)); final String morelCode = apply.toString(); ThreadLocals.let(CalciteFunctions.THREAD_ENV, new CalciteFunctions.Context(new Session(ImmutableMap.of()), cx.env, @@ -369,7 +369,7 @@ private static void harmonizeRowTypes(RelBuilder relBuilder, int inputCount) { } } - @Override protected RelCode compileFrom(Context cx, Core.From from) { + @Override protected Code compileFrom(Context cx, Core.From from) { final Code code = super.compileFrom(cx, from); return new RelCode() { @Override public Describer describe(Describer describer) { @@ -400,7 +400,7 @@ private static void harmonizeRowTypes(RelBuilder relBuilder, int inputCount) { || getLast(from.steps).op != Op.YIELD) { final Core.Exp implicitYieldExp = core.implicitYieldExp(typeSystem, from.steps); - cx = yield_(cx, implicitYieldExp); + cx = yield_(cx, ImmutableList.of(), implicitYieldExp); } return true; } @@ -411,6 +411,10 @@ private RelContext step(RelContext cx, int i, Core.FromStep fromStep) { return join(cx, i, (Core.Scan) fromStep); case WHERE: return where(cx, (Core.Where) fromStep); + case SKIP: + return skip(cx, (Core.Skip) fromStep); + case TAKE: + return take(cx, (Core.Take) fromStep); case ORDER: return order(cx, (Core.Order) fromStep); case GROUP: @@ -425,25 +429,28 @@ private RelContext step(RelContext cx, int i, Core.FromStep fromStep) { } private RelContext yield_(RelContext cx, Core.Yield yield) { - return yield_(cx, yield.exp); + return yield_(cx, yield.bindings, yield.exp); } - private RelContext yield_(RelContext cx, Core.Exp exp) { + private RelContext yield_(RelContext cx, List bindings, + Core.Exp exp) { final Core.Tuple tuple; switch (exp.op) { case ID: final Core.Id id = (Core.Id) exp; tuple = toRecord(cx, id); if (tuple != null) { - return yield_(cx, tuple); + return yield_(cx, bindings, tuple); } break; case TUPLE: tuple = (Core.Tuple) exp; + final List names = + ImmutableList.copyOf(tuple.type().argNameTypes().keySet()); cx.relBuilder.project(transform(tuple.args, e -> translate(cx, e)), - ImmutableList.copyOf(tuple.type().argNameTypes().keySet())); - return cx; + names); + return getRelContext(cx, cx.env.bindAll(bindings), names); } RexNode rex = translate(cx, exp); cx.relBuilder.project(rex); @@ -489,7 +496,7 @@ record = toRecord(cx, id); if (cx.map.containsKey(id.idPat.name)) { // Not a record, so must be a scalar. It is represented in Calcite // as a record with one field. - final VarData fn = cx.map.get(id.idPat.name); + final VarData fn = requireNonNull(cx.map.get(id.idPat.name)); return fn.apply(cx.relBuilder); } break; @@ -611,9 +618,9 @@ private RexNode morelScalar(RelContext cx, Core.Exp exp) { final RelDataType calciteType = Converters.toCalciteType(exp.type, typeFactory); final JsonBuilder jsonBuilder = new JsonBuilder(); + final RelJson relJson = RelJson.create().withJsonBuilder(jsonBuilder); final String jsonType = - jsonBuilder.toJsonString( - new RelJson(jsonBuilder).toJson(calciteType)); + jsonBuilder.toJsonString(relJson.toJson(calciteType)); final String morelCode = exp.toString(); return cx.relBuilder.getRexBuilder().makeCall(calciteType, CalciteFunctions.SCALAR_OPERATOR, @@ -695,6 +702,26 @@ private RelContext where(RelContext cx, Core.Where where) { return cx; } + private RelContext skip(RelContext cx, Core.Skip skip) { + if (skip.exp.op != Op.INT_LITERAL) { + throw new AssertionError("skip requires literal: " + skip.exp); + } + int offset = ((Core.Literal) skip.exp).unwrap(Integer.class); + int fetch = -1; // per Calcite: "negative means no limit" + cx.relBuilder.limit(offset, fetch); + return cx; + } + + private RelContext take(RelContext cx, Core.Take take) { + if (take.exp.op != Op.INT_LITERAL) { + throw new AssertionError("take requires literal: " + take.exp); + } + int offset = 0; + int fetch = ((Core.Literal) take.exp).unwrap(Integer.class); + cx.relBuilder.limit(offset, fetch); + return cx; + } + private RelContext order(RelContext cx, Core.Order order) { final List exps = new ArrayList<>(); order.orderItems.forEach(i -> { @@ -709,7 +736,6 @@ private RelContext order(RelContext cx, Core.Order order) { } private RelContext group(RelContext cx, Core.Group group) { - final SortedMap map = new TreeMap<>(); final List bindings = new ArrayList<>(); final List nodes = new ArrayList<>(); final List names = new ArrayList<>(); @@ -734,20 +760,24 @@ private RelContext group(RelContext cx, Core.Group group) { // Create an Aggregate operator. cx.relBuilder.aggregate(groupKey, aggregateCalls); + return getRelContext(cx, cx.env.bindAll(bindings), names); + } + private static RelContext getRelContext(RelContext cx, Environment env, + List names) { // Permute the fields so that they are sorted by name, per Morel records. final List sortedNames = Ordering.natural().immutableSortedCopy(names); cx.relBuilder.rename(names) .project(cx.relBuilder.fields(sortedNames)); final RelDataType rowType = cx.relBuilder.peek().getRowType(); - sortedNames.forEach(name -> { - final int i = map.size(); - map.put(name, new VarData(PrimitiveType.UNIT, i, rowType)); - }); + final SortedMap map = new TreeMap<>(); + sortedNames.forEach(name -> + map.put(name, + new VarData(PrimitiveType.UNIT, map.size(), rowType))); // Return a context containing a variable for each output field. - return new RelContext(cx.env.bindAll(bindings), cx, cx.relBuilder, + return new RelContext(env, cx, cx.relBuilder, ImmutableSortedMap.copyOfSorted(map), 1); } diff --git a/src/main/java/net/hydromatic/morel/compile/Compiler.java b/src/main/java/net/hydromatic/morel/compile/Compiler.java index 4dc633f5..d7f4d91c 100644 --- a/src/main/java/net/hydromatic/morel/compile/Compiler.java +++ b/src/main/java/net/hydromatic/morel/compile/Compiler.java @@ -151,7 +151,7 @@ void apply(Consumer outLines, Consumer outBindings, } /** Compilation context. */ - static class Context { + public static class Context { final Environment env; Context(Environment env) { @@ -388,6 +388,16 @@ && getOnlyElement(bindings).id.type.equals(elementType)) { final Code filterCode = compile(cx, where.exp); return () -> Codes.whereRowSink(filterCode, nextFactory.get()); + case SKIP: + final Core.Skip skip = (Core.Skip) firstStep; + final Code skipCode = compile(cx, skip.exp); + return () -> Codes.skipRowSink(skipCode, nextFactory.get()); + + case TAKE: + final Core.Take take = (Core.Take) firstStep; + final Code takeCode = compile(cx, take.exp); + return () -> Codes.takeRowSink(takeCode, nextFactory.get()); + case YIELD: final Core.Yield yield = (Core.Yield) firstStep; if (steps.size() == 1) { diff --git a/src/main/java/net/hydromatic/morel/compile/Resolver.java b/src/main/java/net/hydromatic/morel/compile/Resolver.java index 2318d776..5dff9118 100644 --- a/src/main/java/net/hydromatic/morel/compile/Resolver.java +++ b/src/main/java/net/hydromatic/morel/compile/Resolver.java @@ -919,6 +919,16 @@ Core.Exp run(Ast.From from) { fromBuilder.where(r.toCore(where.exp)); } + @Override protected void visit(Ast.Skip skip) { + final Resolver r = withEnv(env); // do not use 'from' bindings + fromBuilder.skip(r.toCore(skip.exp)); + } + + @Override protected void visit(Ast.Take take) { + final Resolver r = withEnv(env); // do not use 'from' bindings + fromBuilder.take(r.toCore(take.exp)); + } + @Override protected void visit(Ast.Yield yield) { final Resolver r = withEnv(fromBuilder.bindings()); fromBuilder.yield_(r.toCore(yield.exp)); diff --git a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java index 3180d370..cd70681a 100644 --- a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java +++ b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java @@ -546,6 +546,22 @@ private Pair deduceStepType(TypeEnv env, fromSteps.add(where.copy(filter2)); return Pair.of(env2, v); + case SKIP: + final Ast.Skip skip = (Ast.Skip) step; + final Unifier.Variable v11 = unifier.variable(); + final Ast.Exp skipCount = deduceType(env2, skip.exp, v11); + equiv(v11, toTerm(PrimitiveType.INT)); + fromSteps.add(skip.copy(skipCount)); + return Pair.of(env2, v); + + case TAKE: + final Ast.Take take = (Ast.Take) step; + final Unifier.Variable v12 = unifier.variable(); + final Ast.Exp takeCount = deduceType(env2, take.exp, v12); + equiv(v12, toTerm(PrimitiveType.INT)); + fromSteps.add(take.copy(takeCount)); + return Pair.of(env2, v); + case YIELD: final Ast.Yield yield = (Ast.Yield) step; final Unifier.Variable v6 = unifier.variable(); diff --git a/src/main/java/net/hydromatic/morel/eval/Codes.java b/src/main/java/net/hydromatic/morel/eval/Codes.java index 7533d3d8..1a65ebef 100644 --- a/src/main/java/net/hydromatic/morel/eval/Codes.java +++ b/src/main/java/net/hydromatic/morel/eval/Codes.java @@ -507,6 +507,7 @@ public static Code from(Supplier rowSinkFactory) { @Override public Object eval(EvalEnv env) { final RowSink rowSink = rowSinkFactory.get(); + rowSink.start(env); rowSink.accept(env); return rowSink.result(env); } @@ -524,6 +525,16 @@ public static RowSink whereRowSink(Code filterCode, RowSink rowSink) { return new WhereRowSink(filterCode, rowSink); } + /** Creates a {@link RowSink} for a {@code skip} clause. */ + public static RowSink skipRowSink(Code filterCode, RowSink rowSink) { + return new SkipRowSink(filterCode, rowSink); + } + + /** Creates a {@link RowSink} for a {@code take} clause. */ + public static RowSink takeRowSink(Code filterCode, RowSink rowSink) { + return new TakeRowSink(filterCode, rowSink); + } + /** Creates a {@link RowSink} for a {@code order} clause. */ public static RowSink orderRowSink( Iterable> codes, @@ -3042,26 +3053,47 @@ public Object eval(EvalEnv env) { /** Accepts rows produced by a supplier as part of a {@code from} clause. */ public interface RowSink extends Describable { + void start(EvalEnv env); void accept(EvalEnv env); List result(EvalEnv env); } + /** Abstract implementation for row sinks that have one successor. */ + abstract static class BaseRowSink implements RowSink { + final RowSink rowSink; + + BaseRowSink(RowSink rowSink) { + this.rowSink = requireNonNull(rowSink); + } + + @Override public void start(EvalEnv env) { + rowSink.start(env); + } + + @Override public void accept(EvalEnv env) { + rowSink.accept(env); + } + + @Override public List result(EvalEnv env) { + return rowSink.result(env); + } + } + /** Implementation of {@link RowSink} for a {@code join} clause. */ - static class ScanRowSink implements RowSink { + static class ScanRowSink extends BaseRowSink { final Op op; // inner, left, right, full private final Core.Pat pat; private final Code code; final Code conditionCode; - final RowSink rowSink; ScanRowSink(Op op, Core.Pat pat, Code code, Code conditionCode, RowSink rowSink) { + super(rowSink); checkArgument(op == Op.INNER_JOIN); this.op = op; this.pat = pat; this.code = code; this.conditionCode = conditionCode; - this.rowSink = rowSink; } @Override public Describer describe(Describer describer) { @@ -3078,7 +3110,7 @@ private static boolean isConstantTrue(Code code) { && Objects.equals(code.eval(null), true); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { final MutableEvalEnv mutableEvalEnv = env.bindMutablePat(pat); final Iterable elements = (Iterable) code.eval(env); for (Object element : elements) { @@ -3090,20 +3122,15 @@ public void accept(EvalEnv env) { } } } - - public List result(EvalEnv env) { - return rowSink.result(env); - } } /** Implementation of {@link RowSink} for a {@code where} clause. */ - static class WhereRowSink implements RowSink { + static class WhereRowSink extends BaseRowSink { final Code filterCode; - final RowSink rowSink; WhereRowSink(Code filterCode, RowSink rowSink) { + super(rowSink); this.filterCode = filterCode; - this.rowSink = rowSink; } @Override public Describer describe(Describer describer) { @@ -3112,38 +3139,92 @@ static class WhereRowSink implements RowSink { .arg("sink", rowSink)); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if ((Boolean) filterCode.eval(env)) { rowSink.accept(env); } } + } - public List result(EvalEnv env) { - return rowSink.result(env); + /** Implementation of {@link RowSink} for a {@code skip} clause. */ + static class SkipRowSink extends BaseRowSink { + final Code skipCode; + int skip; + + SkipRowSink(Code skipCode, RowSink rowSink) { + super(rowSink); + this.skipCode = skipCode; + } + + @Override public Describer describe(Describer describer) { + return describer.start("skip", d -> + d.arg("count", skipCode) + .arg("sink", rowSink)); + } + + @Override public void start(EvalEnv env) { + skip = (Integer) skipCode.eval(env); + super.start(env); + } + + @Override public void accept(EvalEnv env) { + if (skip > 0) { + --skip; + } else { + rowSink.accept(env); + } + } + } + + /** Implementation of {@link RowSink} for a {@code take} clause. */ + static class TakeRowSink extends BaseRowSink { + final Code takeCode; + int take; + + TakeRowSink(Code takeCode, RowSink rowSink) { + super(rowSink); + this.takeCode = takeCode; + } + + @Override public Describer describe(Describer describer) { + return describer.start("take", d -> + d.arg("count", takeCode) + .arg("sink", rowSink)); + } + + @Override public void start(EvalEnv env) { + take = (Integer) takeCode.eval(env); + super.start(env); + } + + @Override public void accept(EvalEnv env) { + if (take > 0) { + --take; + rowSink.accept(env); + } } } /** Implementation of {@link RowSink} for a {@code group} clause. */ - private static class GroupRowSink implements RowSink { + private static class GroupRowSink extends BaseRowSink { final Code keyCode; final ImmutableList inNames; final ImmutableList keyNames; /** group names followed by aggregate names */ final ImmutableList outNames; final ImmutableList aggregateCodes; - final RowSink rowSink; final ListMultimap map = ArrayListMultimap.create(); final Object[] values; GroupRowSink(Code keyCode, ImmutableList aggregateCodes, ImmutableList inNames, ImmutableList keyNames, ImmutableList outNames, RowSink rowSink) { + super(rowSink); this.keyCode = requireNonNull(keyCode); this.aggregateCodes = requireNonNull(aggregateCodes); this.inNames = requireNonNull(inNames); this.keyNames = requireNonNull(keyNames); this.outNames = requireNonNull(outNames); - this.rowSink = requireNonNull(rowSink); this.values = inNames.size() == 1 ? null : new Object[inNames.size()]; checkArgument(isPrefix(keyNames, outNames)); } @@ -3161,7 +3242,7 @@ private static boolean isPrefix(List list0, List list1) { }); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if (inNames.size() == 1) { map.put(keyCode.eval(env), env.getOpt(inNames.get(0))); } else { @@ -3172,7 +3253,7 @@ public void accept(EvalEnv env) { } } - public List result(final EvalEnv env) { + @Override public List result(final EvalEnv env) { // Derive env2, the environment for our consumer. It consists of our input // environment plus output names. EvalEnv env2 = env; @@ -3213,18 +3294,17 @@ public List result(final EvalEnv env) { } /** Implementation of {@link RowSink} for an {@code order} clause. */ - static class OrderRowSink implements RowSink { + static class OrderRowSink extends BaseRowSink { final ImmutablePairList codes; final ImmutableList names; - final RowSink rowSink; final List rows = new ArrayList<>(); final Object[] values; OrderRowSink(ImmutablePairList codes, ImmutableList names, RowSink rowSink) { + super(rowSink); this.codes = codes; this.names = names; - this.rowSink = rowSink; this.values = names.size() == 1 ? null : new Object[names.size()]; } @@ -3235,7 +3315,7 @@ static class OrderRowSink implements RowSink { }); } - public void accept(EvalEnv env) { + @Override public void accept(EvalEnv env) { if (values == null) { rows.add(env.getOpt(names.get(0))); } else { @@ -3246,7 +3326,7 @@ public void accept(EvalEnv env) { } } - public List result(final EvalEnv env) { + @Override public List result(final EvalEnv env) { final MutableEvalEnv leftEnv = env.bindMutableArray(names); final MutableEvalEnv rightEnv = env.bindMutableArray(names); rows.sort((left, right) -> { @@ -3277,17 +3357,16 @@ public List result(final EvalEnv env) { * step is allowed to generate expressions that are not records. Non-record * expressions (e.g. {@code int} expressions) do not have a name, and * therefore the value cannot be passed via the {@link EvalEnv}. */ - private static class YieldRowSink implements RowSink { + private static class YieldRowSink extends BaseRowSink { private final ImmutableList names; private final ImmutableList codes; - private final RowSink rowSink; private final Object[] values; YieldRowSink(ImmutableList names, ImmutableList codes, RowSink rowSink) { + super(rowSink); this.names = names; this.codes = codes; - this.rowSink = rowSink; this.values = names.size() == 1 ? null : new Object[names.size()]; } @@ -3310,10 +3389,6 @@ private static class YieldRowSink implements RowSink { } rowSink.accept(env2); } - - @Override public List result(EvalEnv env) { - return rowSink.result(env); - } } /** Implementation of {@link RowSink} that the last step of a {@code from} @@ -3330,6 +3405,10 @@ private static class CollectRowSink implements RowSink { return describer.start("collect", d -> d.arg("", code)); } + @Override public void start(EvalEnv env) { + list.clear(); + } + @Override public void accept(EvalEnv env) { list.add(code.eval(env)); } diff --git a/src/main/javacc/MorelParser.jj b/src/main/javacc/MorelParser.jj index b2d6ad9c..7ddcb8cf 100644 --- a/src/main/javacc/MorelParser.jj +++ b/src/main/javacc/MorelParser.jj @@ -357,6 +357,8 @@ void fromStep(List steps) : final Pair patExp; final Exp condition; final Exp filterExp; + final Exp skipExp; + final Exp takeExp; final Exp yieldExp; final PairList groupExps; final List aggregates; @@ -406,6 +408,14 @@ void fromStep(List steps) : { span = Span.of(pos()); } orderItems = orderItemCommaList() { steps.add(ast.order(span.end(this), orderItems)); } +| + { span = Span.of(pos()); } skipExp = expression() { + steps.add(ast.skip(span.end(this), skipExp)); + } +| + { span = Span.of(pos()); } takeExp = expression() { + steps.add(ast.take(span.end(this), takeExp)); + } | { span = Span.of(pos()); } yieldExp = expression() { steps.add(ast.yield(span.end(this), yieldExp)); @@ -1418,6 +1428,8 @@ AstNode statementEof() : | < JOIN: "join" > | < ON: "on" > | < ORDER: "order" > +| < SKIP_: "skip" > +| < TAKE: "take" > | < WHERE: "where" > | < YIELD: "yield" > } diff --git a/src/test/java/net/hydromatic/morel/AlgebraTest.java b/src/test/java/net/hydromatic/morel/AlgebraTest.java index 1a8e1c3d..d91996f0 100644 --- a/src/test/java/net/hydromatic/morel/AlgebraTest.java +++ b/src/test/java/net/hydromatic/morel/AlgebraTest.java @@ -68,6 +68,27 @@ public class AlgebraTest { 10)); } + @Test void testScottOrder() { + final String ml = "from e in scott.emp\n" + + " yield {e.empno, e.deptno}\n" + + " order empno desc\n" + + " skip 2 take 4"; + // When fixed, + // [CALCITE-6128] RelBuilder.sortLimit should compose offset and fetch + // will yield a plan with one fewer LogicalSort + final String plan = "LogicalSort(fetch=[4])\n" + + " LogicalSort(sort0=[$1], dir0=[DESC], offset=[2])\n" + + " LogicalProject(deptno=[$7], empno=[$0])\n" + + " JdbcTableScan(table=[[scott, EMP]])\n"; + ml(ml) + .withBinding("scott", BuiltInDataSet.SCOTT) + .assertType("{deptno:int, empno:int} list") + .assertCalcite(is(plan)) + .assertEvalIter( + equalsOrdered(list(30, 7900), list(20, 7876), list(30, 7844), + list(10, 7839))); + } + @Test void testScottJoin() { final String ml = "let\n" + " val emps = #emp scott\n" @@ -153,6 +174,13 @@ public class AlgebraTest { + "group r.b compute sb = sum of r.b,\n" + " mb = min of r.b, a = count\n" + "yield {a, a2 = a + b, sb}", + "from e in scott.emp\n" + + "yield {e.ename, x = e.deptno * 2}", + "from e in scott.emp\n" + + "order e.ename", + "from e in scott.emp\n" + + "order e.ename desc\n" + + "take 3", "from e in scott.emp,\n" + " d in scott.dept\n" + "where e.deptno = d.deptno\n" diff --git a/src/test/java/net/hydromatic/morel/MainTest.java b/src/test/java/net/hydromatic/morel/MainTest.java index 9e64e7ae..9035501a 100644 --- a/src/test/java/net/hydromatic/morel/MainTest.java +++ b/src/test/java/net/hydromatic/morel/MainTest.java @@ -1752,6 +1752,12 @@ private static List node(Object... args) { .assertParse("from e in emps" + " group id = #id e compute count = count" + " join d in depts where false"); + ml("from e in emps skip 1 take 2").assertParseSame(); + ml("from e in emps order e.empno take 2") + .assertParse("from e in emps order #empno e take 2"); + ml("from e in emps order e.empno take 2 skip 3 skip 1+1 take 2") + .assertParse("from e in emps order #empno e take 2 skip 3 skip 1 + 1 " + + "take 2"); ml("fn f => from i in [1, 2, 3] where f i") .assertParseSame() .assertType("(int -> bool) -> int list"); @@ -2399,10 +2405,14 @@ private static List node(Object... args) { @Test void testFromOrderYield() { final String ml = "from r in [{a=1,b=2},{a=1,b=0},{a=2,b=1}]\n" + " order r.a desc, r.b\n" + + " skip 0\n" + + " take 4 + 6\n" + " yield {r.a, b10 = r.b * 10}"; final String expected = "from r in" + " [{a = 1, b = 2}, {a = 1, b = 0}, {a = 2, b = 1}]" + " order #a r desc, #b r" + + " skip 0" + + " take 4 + 6" + " yield {a = #a r, b10 = #b r * 10}"; ml(ml).assertParse(expected) .assertType(hasMoniker("{a:int, b10:int} list")) diff --git a/src/test/resources/script/builtIn.smli b/src/test/resources/script/builtIn.smli index 4ef77491..fc298945 100644 --- a/src/test/resources/script/builtIn.smli +++ b/src/test/resources/script/builtIn.smli @@ -697,20 +697,20 @@ Sys.plan (); > : string (*) val take : 'a list * int -> 'a list -List.take; +List.`take`; > val it = fn : 'a list * int -> 'a list -List.take ([1,2,3], 0); +List.`take` ([1,2,3], 0); > val it = [] : int list -List.take ([1,2,3], 1); +List.`take` ([1,2,3], 1); > val it = [1] : int list -List.take ([1,2,3], 3); +List.`take` ([1,2,3], 3); > val it = [1,2,3] : int list -List.take ([1,2,3], 4); +List.`take` ([1,2,3], 4); > uncaught exception Subscript [subscript out of bounds] -> raised at: stdIn:1.1-1.23 -List.take ([1,2,3], ~1); +> raised at: stdIn:1.1-1.25 +List.`take` ([1,2,3], ~1); > uncaught exception Subscript [subscript out of bounds] -> raised at: stdIn:1.1-1.24 +> raised at: stdIn:1.1-1.26 Sys.plan (); > val it = > "apply2(fnValue List.take, tuple(constant(1), constant(2), constant(3)), constant(-1))" diff --git a/src/test/resources/script/relational.smli b/src/test/resources/script/relational.smli index 29c0ac29..69cf6914 100644 --- a/src/test/resources/script/relational.smli +++ b/src/test/resources/script/relational.smli @@ -175,6 +175,41 @@ from e in emps order deptno desc; > val it = [30,30,20,10] : int list +from e in emps + yield {e.deptno} + order deptno desc + skip 1; +> val it = [30,20,10] : int list + +from e in emps + yield {e.deptno} + order deptno desc + skip 1 + take 2; +> val it = [30,20] : int list + +from e in emps + yield {e.deptno} + order deptno desc + take 2; +> val it = [30,30] : int list + +(*) Pass 'take' and 'skip' via function arguments +let + fun earlyEmps n = + from e in emps + yield {e.id, e.deptno} + order id + skip n - 2 + take n +in + (earlyEmps 2, earlyEmps 3) +end; +> val it = +> ([{deptno=10,id=100},{deptno=20,id=101}], +> [{deptno=20,id=101},{deptno=30,id=102},{deptno=30,id=103}]) +> : {deptno:int, id:int} list * {deptno:int, id:int} list + (*) 'yield' followed by 'order' from e in emps yield {e.deptno, x = e.deptno, e.name} diff --git a/src/test/resources/script/wordle.smli b/src/test/resources/script/wordle.smli index 703f8269..ddd60299 100644 --- a/src/test/resources/script/wordle.smli +++ b/src/test/resources/script/wordle.smli @@ -131,7 +131,7 @@ fun bestGuesses words = > val bestGuesses = fn : string list -> {f:int, w:string} list (*) Run on a sample of 500 words; faster than the full set of 12,972 words. -val sampleWords = if slow then words else List.take (words, 200); +val sampleWords = if slow then words else List.`take` (words, 200); > val sampleWords = > ["aahed","aalii","aargh","aarti","abaca","abaci","aback","abacs","abaft", > "abaka","abamp","aband","abase","abash","abask","abate","abaya","abbas",