From 4d6325b661635ef61e52d420ea1bcf5006389c2f Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 7 May 2019 23:58:40 +0200 Subject: [PATCH] Implement Offset --- .../io/prestosql/sql/analyzer/Analysis.java | 13 +++ .../sql/analyzer/SemanticErrorCode.java | 1 + .../sql/analyzer/StatementAnalyzer.java | 13 ++- .../prestosql/sql/planner/PlanOptimizers.java | 19 +++ .../prestosql/sql/planner/QueryPlanner.java | 17 +++ .../rule/ImplementOffsetOverOther.java | 91 +++++++++++++++ .../rule/ImplementOffsetOverProjectSort.java | 109 +++++++++++++++++ .../rule/ImplementOffsetOverProjectTopN.java | 110 ++++++++++++++++++ .../rule/ImplementOffsetOverSort.java | 107 +++++++++++++++++ .../rule/ImplementOffsetOverTopN.java | 108 +++++++++++++++++ .../rule/PushLimitThroughOffset.java | 78 +++++++++++++ .../rule/PushOffsetThroughProject.java | 64 ++++++++++ .../optimizations/QueryCardinalityUtil.java | 17 +++ .../UnaliasSymbolReferences.java | 7 ++ .../sql/planner/plan/OffsetNode.java | 86 ++++++++++++++ .../prestosql/sql/planner/plan/Patterns.java | 5 + .../sql/planner/plan/PlanVisitor.java | 5 + .../sanity/ValidateDependenciesChecker.java | 10 ++ .../prestosql/sql/analyzer/TestAnalyzer.java | 7 ++ .../sql/planner/TestLogicalPlanner.java | 83 +++++++++++++ .../sql/planner/assertions/OffsetMatcher.java | 51 ++++++++ .../planner/assertions/PlanMatchPattern.java | 6 + .../assertions/RowNumberSymbolMatcher.java | 47 ++++++++ .../rule/TestImplementOffsetOverOther.java | 54 +++++++++ .../TestImplementOffsetOverProjectSort.java | 66 +++++++++++ .../TestImplementOffsetOverProjectTopN.java | 69 +++++++++++ .../rule/TestImplementOffsetOverSort.java | 63 ++++++++++ .../rule/TestImplementOffsetOverTopN.java | 66 +++++++++++ .../rule/TestPushLimitThroughOffset.java | 50 ++++++++ .../rule/TestPushOffsetThroughProject.java | 62 ++++++++++ .../iterative/rule/test/PlanBuilder.java | 6 + .../prestosql/tests/AbstractTestQueries.java | 43 +++++++ 32 files changed, 1532 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverOther.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectSort.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectTopN.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverSort.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverTopN.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughOffset.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushOffsetThroughProject.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/plan/OffsetNode.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/assertions/OffsetMatcher.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/assertions/RowNumberSymbolMatcher.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverOther.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectSort.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectTopN.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverSort.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverTopN.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughOffset.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushOffsetThroughProject.java diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java index deec6a8bd6dd..4ecd9c864ba6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java @@ -35,6 +35,7 @@ import io.prestosql.sql.tree.LambdaArgumentDeclaration; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NodeRef; +import io.prestosql.sql.tree.Offset; import io.prestosql.sql.tree.OrderBy; import io.prestosql.sql.tree.QuantifiedComparisonExpression; import io.prestosql.sql.tree.Query; @@ -98,6 +99,7 @@ public class Analysis private final Map, List> outputExpressions = new LinkedHashMap<>(); private final Map, List> windowFunctions = new LinkedHashMap<>(); private final Map, List> orderByWindowFunctions = new LinkedHashMap<>(); + private final Map, Long> offset = new LinkedHashMap<>(); private final Map, OptionalLong> limit = new LinkedHashMap<>(); private final Map, Expression> joins = new LinkedHashMap<>(); @@ -318,6 +320,17 @@ public List getOrderByExpressions(Node node) return orderByExpressions.get(NodeRef.of(node)); } + public void setOffset(Offset node, long rowCount) + { + offset.put(NodeRef.of(node), rowCount); + } + + public long getOffset(Offset node) + { + checkState(offset.containsKey(NodeRef.of(node)), "missing OFFSET value for node %s", node); + return offset.get(NodeRef.of(node)); + } + public void setLimit(Node node, OptionalLong rowCount) { limit.put(NodeRef.of(node), rowCount); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java index 658c1aedc90f..6d6a1d08086e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/SemanticErrorCode.java @@ -100,6 +100,7 @@ public enum SemanticErrorCode TOO_MANY_GROUPING_SETS, + INVALID_OFFSET_ROW_COUNT, INVALID_FETCH_FIRST_ROW_COUNT, INVALID_LIMIT_ROW_COUNT, } diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java index 4610f084fb73..6d2ba7f92161 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java @@ -182,6 +182,7 @@ import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME; @@ -2135,7 +2136,17 @@ private List analyzeOrderBy(Node node, List sortItems, Sco private void analyzeOffset(Offset node) { - throw new SemanticException(NOT_SUPPORTED, node, "OFFSET not yet implemented"); + long rowCount; + try { + rowCount = Long.parseLong(node.getRowCount()); + } + catch (NumberFormatException e) { + throw new SemanticException(INVALID_OFFSET_ROW_COUNT, node, "Invalid OFFSET row count: %s", node.getRowCount()); + } + if (rowCount < 0) { + throw new SemanticException(INVALID_OFFSET_ROW_COUNT, node, "OFFSET row count must be greater or equal to 0 (actual value: %s)", rowCount); + } + analysis.setOffset(node, rowCount); } private void analyzeLimit(Node node) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index bccc1dcffe94..52c726217981 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -48,6 +48,11 @@ import io.prestosql.sql.planner.iterative.rule.GatherAndMergeWindows; import io.prestosql.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import io.prestosql.sql.planner.iterative.rule.ImplementFilteredAggregations; +import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverOther; +import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverProjectSort; +import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverProjectTopN; +import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverSort; +import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverTopN; import io.prestosql.sql.planner.iterative.rule.InlineProjections; import io.prestosql.sql.planner.iterative.rule.MergeFilters; import io.prestosql.sql.planner.iterative.rule.MergeLimitOverProjectWithSort; @@ -78,9 +83,11 @@ import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushLimitIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; +import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOffset; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughProject; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughSemiJoin; +import io.prestosql.sql.planner.iterative.rule.PushOffsetThroughProject; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughJoin; import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -289,6 +296,8 @@ public PlanOptimizers( new EvaluateZeroLimit(), new EvaluateZeroTopN(), new EvaluateZeroSample(), + new PushOffsetThroughProject(), + new PushLimitThroughOffset(), new PushLimitThroughProject(), new MergeLimits(), new MergeLimitWithSort(), @@ -309,6 +318,16 @@ public PlanOptimizers( new PruneOrderByInAggregation(metadata.getFunctionRegistry()), new RewriteSpatialPartitioningAggregation(metadata))) .build()), + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of( + new ImplementOffsetOverTopN(), + new ImplementOffsetOverProjectTopN(), + new ImplementOffsetOverSort(), + new ImplementOffsetOverProjectSort(), + new ImplementOffsetOverOther())), simplifyOptimizer, new UnaliasSymbolReferences(), new IterativeOptimizer( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index 4290e728e514..7e341230b34b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -37,6 +37,7 @@ import io.prestosql.sql.planner.plan.FilterNode; import io.prestosql.sql.planner.plan.GroupIdNode; import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.planner.plan.SortNode; @@ -55,6 +56,7 @@ import io.prestosql.sql.tree.LambdaExpression; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NodeRef; +import io.prestosql.sql.tree.Offset; import io.prestosql.sql.tree.OrderBy; import io.prestosql.sql.tree.Query; import io.prestosql.sql.tree.QuerySpecification; @@ -133,6 +135,7 @@ public RelationPlan plan(Query query) builder = sort(builder, query.getOrderBy(), analysis.getOrderByExpressions(query)); builder = project(builder, analysis.getOutputExpressions(query)); + builder = offset(builder, query.getOffset()); builder = limit(builder, query.getLimit()); return new RelationPlan( @@ -183,6 +186,7 @@ public RelationPlan plan(QuerySpecification node) builder = distinct(builder, node); builder = sort(builder, node.getOrderBy(), analysis.getOrderByExpressions(node)); builder = project(builder, outputs); + builder = offset(builder, node.getOffset()); builder = limit(builder, node.getLimit()); return new RelationPlan( @@ -881,6 +885,19 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, List offset) + { + if (!offset.isPresent()) { + return subPlan; + } + + return subPlan.withNewRoot( + new OffsetNode( + idAllocator.getNextId(), + subPlan.getRoot(), + analysis.getOffset(offset.get()))); + } + private PlanBuilder limit(PlanBuilder subPlan, Optional limit) { if (limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverOther.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverOther.java new file mode 100644 index 000000000000..87c2608104e8 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverOther.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.GenericLiteral; + +import java.util.Optional; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.source; + +/** + * Transforms: + *
+ * - Offset (row count = x)
+ *    - Source (other than Sort, TopN)
+ * 
+ * Into: + *
+ * - Project (prune rowNumber symbol)
+ *    - Filter (rowNumber > x)
+ *       - RowNumber
+ *          - Source
+ * 
+ */ +public class ImplementOffsetOverOther + implements Rule +{ + private static final Pattern PATTERN = offset() + .with(source().matching(node -> !(node instanceof TopNNode || node instanceof SortNode))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT); + + RowNumberNode rowNumberNode = new RowNumberNode( + context.getIdAllocator().getNextId(), + parent.getSource(), + ImmutableList.of(), + rowNumberSymbol, + Optional.empty(), + Optional.empty()); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + rowNumberNode, + new ComparisonExpression( + ComparisonExpression.Operator.GREATER_THAN, + rowNumberSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", Long.toString(parent.getCount())))); + + ProjectNode projectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + filterNode, + Assignments.identity(parent.getOutputSymbols())); + + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectSort.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectSort.java new file mode 100644 index 000000000000..2fcc7162774c --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectSort.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.GenericLiteral; + +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.sort; +import static io.prestosql.sql.planner.plan.Patterns.source; + +/** + * Transforms: + *
+ * - Offset (row count = x)
+ *    - Project (prunes sort symbols no longer useful)
+ *       - Sort (order by a, b, c)
+ * 
+ * Into: + *
+ * - Project (prunes rowNumber symbol and sort symbols no longer useful)
+ *    - Sort (order by rowNumber)
+ *       - Filter (rowNumber > x)
+ *          - RowNumber
+ *             - Sort (order by a, b, c)
+ * 
+ */ +public class ImplementOffsetOverProjectSort + implements Rule +{ + private static final Capture PROJECT = newCapture(); + private static final Capture SORT = newCapture(); + + private static final Pattern PATTERN = offset() + .with(source().matching( + project().capturedAs(PROJECT).matching(ProjectNode::isIdentity) + .with(source().matching( + sort().capturedAs(SORT))))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + ProjectNode project = captures.get(PROJECT); + SortNode sort = captures.get(SORT); + + Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT); + + RowNumberNode rowNumberNode = new RowNumberNode( + context.getIdAllocator().getNextId(), + sort, + ImmutableList.of(), + rowNumberSymbol, + Optional.empty(), + Optional.empty()); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + rowNumberNode, + new ComparisonExpression( + ComparisonExpression.Operator.GREATER_THAN, + rowNumberSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", Long.toString(parent.getCount())))); + + SortNode sortNode = new SortNode( + context.getIdAllocator().getNextId(), + filterNode, + new OrderingScheme(ImmutableList.of(rowNumberSymbol), ImmutableMap.of(rowNumberSymbol, SortOrder.ASC_NULLS_FIRST))); + + ProjectNode projectNode = (ProjectNode) project.replaceChildren(ImmutableList.of(sortNode)); + + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectTopN.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectTopN.java new file mode 100644 index 000000000000..7de6abe9515a --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverProjectTopN.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.GenericLiteral; + +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.topN; + +/** + * Transforms: + *
+ * - Offset (row count = x)
+ *    - Project (prunes sort symbols no longer useful)
+ *       - TopN (order by a, b, c)
+ * 
+ * Into: + *
+ * - Project (prunes rowNumber symbol and sort symbols no longer useful)
+ *    - Sort (order by rowNumber)
+ *       - Filter (rowNumber > x)
+ *          - RowNumber
+ *             - TopN (order by a, b, c)
+ * 
+ */ +public class ImplementOffsetOverProjectTopN + implements Rule +{ + private static final Capture PROJECT = newCapture(); + private static final Capture TOPN = newCapture(); + + private static final Pattern PATTERN = offset() + .with(source().matching( + project().capturedAs(PROJECT).matching(ProjectNode::isIdentity) + .with(source().matching( + topN().capturedAs(TOPN))))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + ProjectNode project = captures.get(PROJECT); + TopNNode topN = captures.get(TOPN); + + Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT); + + RowNumberNode rowNumberNode = new RowNumberNode( + context.getIdAllocator().getNextId(), + topN, + ImmutableList.of(), + rowNumberSymbol, + Optional.empty(), + Optional.empty()); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + rowNumberNode, + new ComparisonExpression( + ComparisonExpression.Operator.GREATER_THAN, + rowNumberSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", Long.toString(parent.getCount())))); + + SortNode sortNode = new SortNode( + context.getIdAllocator().getNextId(), + filterNode, + new OrderingScheme(ImmutableList.of(rowNumberSymbol), ImmutableMap.of(rowNumberSymbol, SortOrder.ASC_NULLS_FIRST))); + + ProjectNode projectNode = (ProjectNode) project.replaceChildren(ImmutableList.of(sortNode)); + + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverSort.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverSort.java new file mode 100644 index 000000000000..cf5b695f9b21 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverSort.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.GenericLiteral; + +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.sort; +import static io.prestosql.sql.planner.plan.Patterns.source; + +/** + * Transforms: + *
+ * - Offset (row count = x)
+ *    - Sort (order by a, b, c)
+ * 
+ * Into: + *
+ * - Project (prune rowNumber symbol)
+ *    - Sort (order by rowNumber)
+ *       - Filter (rowNumber > x)
+ *          - RowNumber
+ *             - Sort (order by a, b, c)
+ * 
+ */ +public class ImplementOffsetOverSort + implements Rule +{ + private static final Capture SORT = newCapture(); + + private static final Pattern PATTERN = offset() + .with(source().matching( + sort().capturedAs(SORT))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + SortNode sort = captures.get(SORT); + + Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT); + + RowNumberNode rowNumberNode = new RowNumberNode( + context.getIdAllocator().getNextId(), + sort, + ImmutableList.of(), + rowNumberSymbol, + Optional.empty(), + Optional.empty()); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + rowNumberNode, + new ComparisonExpression( + ComparisonExpression.Operator.GREATER_THAN, + rowNumberSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", Long.toString(parent.getCount())))); + + SortNode sortNode = new SortNode( + context.getIdAllocator().getNextId(), + filterNode, + new OrderingScheme(ImmutableList.of(rowNumberSymbol), ImmutableMap.of(rowNumberSymbol, SortOrder.ASC_NULLS_FIRST))); + + ProjectNode projectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + sortNode, + Assignments.identity(sort.getOutputSymbols())); + + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverTopN.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverTopN.java new file mode 100644 index 000000000000..03f0f8b3c6f7 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ImplementOffsetOverTopN.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.GenericLiteral; + +import java.util.Optional; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.topN; + +/** + * Transforms: + *
+ * - Offset (row count = x)
+ *    - TopN (order by a, b, c)
+ * 
+ * Into: + *
+ * - Project (prune rowNumber symbol)
+ *    - Sort (order by rowNumber)
+ *       - Filter (rowNumber > x)
+ *          - RowNumber
+ *             - TopN (order by a, b, c)
+ * 
+ */ +public class ImplementOffsetOverTopN + implements Rule +{ + private static final Capture TOPN = newCapture(); + + private static final Pattern PATTERN = offset() + .with(source().matching( + topN().capturedAs(TOPN))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + TopNNode topN = captures.get(TOPN); + + Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT); + + RowNumberNode rowNumberNode = new RowNumberNode( + context.getIdAllocator().getNextId(), + topN, + ImmutableList.of(), + rowNumberSymbol, + Optional.empty(), + Optional.empty()); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + rowNumberNode, + new ComparisonExpression( + ComparisonExpression.Operator.GREATER_THAN, + rowNumberSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", Long.toString(parent.getCount())))); + + SortNode sortNode = new SortNode( + context.getIdAllocator().getNextId(), + filterNode, + new OrderingScheme(ImmutableList.of(rowNumberSymbol), ImmutableMap.of(rowNumberSymbol, SortOrder.ASC_NULLS_FIRST))); + + ProjectNode projectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + sortNode, + Assignments.identity(topN.getOutputSymbols())); + + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughOffset.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughOffset.java new file mode 100644 index 000000000000..4c3d0da9291f --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughOffset.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.OffsetNode; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.plan.Patterns.limit; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.lang.Math.addExact; + +/** + * Transforms: + *
+ * - Limit (row count x)
+ *    - Offset (row count y)
+ * 
+ * Into: + *
+ * - Offset (row count y)
+ *    - Limit (row count x+y)
+ * 
+ */ +public class PushLimitThroughOffset + implements Rule +{ + private static final Capture CHILD = newCapture(); + + private static final Pattern PATTERN = limit() + .with(source().matching( + offset().capturedAs(CHILD))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(LimitNode parent, Captures captures, Context context) + { + OffsetNode child = captures.get(CHILD); + + long count; + try { + count = addExact(parent.getCount(), child.getCount()); + } + catch (ArithmeticException e) { + return Result.empty(); + } + + return Result.ofPlanNode( + child.replaceChildren(ImmutableList.of( + new LimitNode( + parent.getId(), + child.getSource(), + count, + parent.isPartial())))); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushOffsetThroughProject.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushOffsetThroughProject.java new file mode 100644 index 000000000000..9e0154aff94d --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushOffsetThroughProject.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.ProjectNode; + +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.iterative.rule.Util.transpose; +import static io.prestosql.sql.planner.plan.Patterns.offset; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; + +/** + * Transforms: + *
+ * - Offset
+ *    - Project (non identity)
+ * 
+ * Into: + *
+ * - Project (non identity)
+ *    - Offset
+ * 
+ */ +public class PushOffsetThroughProject + implements Rule +{ + private static final Capture CHILD = newCapture(); + + private static final Pattern PATTERN = offset() + .with(source().matching( + project() + // do not push offset through identity projection which could be there for column pruning purposes + .matching(projectNode -> !projectNode.isIdentity()) + .capturedAs(CHILD))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(OffsetNode parent, Captures captures, Context context) + { + return Result.ofPlanNode(transpose(parent, captures.get(CHILD))); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/QueryCardinalityUtil.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/QueryCardinalityUtil.java index bbc7819e4610..4b631047f03f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/QueryCardinalityUtil.java @@ -21,6 +21,7 @@ import io.prestosql.sql.planner.plan.ExchangeNode; import io.prestosql.sql.planner.plan.FilterNode; import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanVisitor; import io.prestosql.sql.planner.plan.ProjectNode; @@ -29,6 +30,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.sql.planner.iterative.Lookup.noLookup; +import static java.lang.Math.max; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -141,6 +143,21 @@ public Range visitValues(ValuesNode node, Void context) return Range.singleton((long) node.getRows().size()); } + @Override + public Range visitOffset(OffsetNode node, Void context) + { + Range sourceCardinalityRange = node.getSource().accept(this, null); + + long lower = max(sourceCardinalityRange.lowerEndpoint() - node.getCount(), 0L); + + if (sourceCardinalityRange.hasUpperBound()) { + return Range.closed(lower, max(sourceCardinalityRange.upperEndpoint() - node.getCount(), 0L)); + } + else { + return Range.atLeast(lower); + } + } + @Override public Range visitLimit(LimitNode node, Void context) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index 61fdd78d8002..25c771d177f1 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -50,6 +50,7 @@ import io.prestosql.sql.planner.plan.LateralJoinNode; import io.prestosql.sql.planner.plan.LimitNode; import io.prestosql.sql.planner.plan.MarkDistinctNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; @@ -324,6 +325,12 @@ public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext co node.getExchangeType()); } + @Override + public PlanNode visitOffset(OffsetNode node, RewriteContext context) + { + return context.defaultRewrite(node); + } + @Override public PlanNode visitLimit(LimitNode node, RewriteContext context) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/OffsetNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/OffsetNode.java new file mode 100644 index 000000000000..446842e1ce24 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/OffsetNode.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.prestosql.sql.planner.Symbol; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@Immutable +public class OffsetNode + extends PlanNode +{ + private final PlanNode source; + private final long count; + + @JsonCreator + public OffsetNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("count") long count) + { + super(id); + + requireNonNull(source, "source is null"); + checkArgument(count >= 0, "count must be greater than or equal to zero"); + + this.source = source; + this.count = count; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public long getCount() + { + return count; + } + + @Override + public List getOutputSymbols() + { + return source.getOutputSymbols(); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitOffset(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new OffsetNode(getId(), Iterables.getOnlyElement(newChildren), count); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java index fb1b14de8626..c4a873300458 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java @@ -87,6 +87,11 @@ public static Pattern lateralJoin() return typeOf(LateralJoinNode.class); } + public static Pattern offset() + { + return typeOf(OffsetNode.class); + } + public static Pattern limit() { return typeOf(LimitNode.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanVisitor.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanVisitor.java index 91c56f6af997..ad700f87485b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanVisitor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/PlanVisitor.java @@ -49,6 +49,11 @@ public R visitOutput(OutputNode node, C context) return visitPlan(node, context); } + public R visitOffset(OffsetNode node, C context) + { + return visitPlan(node, context); + } + public R visitLimit(LimitNode node, C context) { return visitPlan(node, context); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index 421c1583e9c0..d7e05f50e2cf 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -42,6 +42,7 @@ import io.prestosql.sql.planner.plan.LimitNode; import io.prestosql.sql.planner.plan.MarkDistinctNode; import io.prestosql.sql.planner.plan.MetadataDeleteNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanVisitor; @@ -304,6 +305,15 @@ public Void visitOutput(OutputNode node, Set boundSymbols) return null; } + @Override + public Void visitOffset(OffsetNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + + return null; + } + @Override public Void visitLimit(LimitNode node, Set boundSymbols) { diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java index 8500e0ffd19e..29ce3cc95e20 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java @@ -86,6 +86,7 @@ import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; +import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; @@ -331,6 +332,12 @@ public void testOrderByNonComparable() assertFails(TYPE_MISMATCH, "SELECT x FROM (SELECT approx_set(1) x) ORDER BY x"); } + @Test + public void testOffsetInvalidRowCount() + { + assertFails(INVALID_OFFSET_ROW_COUNT, "SELECT * FROM t1 OFFSET 987654321098765432109876543210 ROWS"); + } + @Test public void testFetchFirstInvalidRowCount() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java index a1ff18e28936..8c770f1654a8 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java @@ -18,7 +18,9 @@ import io.prestosql.Session; import io.prestosql.sql.analyzer.FeaturesConfig.JoinDistributionType; import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; import io.prestosql.sql.planner.optimizations.AddLocalExchanges; import io.prestosql.sql.planner.optimizations.CheckSubqueryNodesAreRewritten; import io.prestosql.sql.planner.optimizations.PlanOptimizer; @@ -74,9 +76,11 @@ import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; @@ -94,6 +98,7 @@ import static io.prestosql.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; import static io.prestosql.sql.tree.SortItem.NullOrdering.LAST; import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; import static io.prestosql.sql.tree.SortItem.Ordering.DESCENDING; @@ -860,4 +865,82 @@ public void testFetch() any( tableScan("nation"))))); } + + @Test + public void testOffset() + { + assertPlan( + "SELECT name FROM nation OFFSET 2 ROWS", + any( + strictProject( + ImmutableMap.of("name", new ExpressionMatcher("name")), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + any( + tableScan("nation", ImmutableMap.of("NAME", "name")))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + + assertPlan( + "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS", + any( + strictProject( + ImmutableMap.of("name", new ExpressionMatcher("name")), + any( + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + any( + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + any( + sort( + ImmutableList.of(sort("regionkey", ASCENDING, LAST)), + any( + tableScan("nation", ImmutableMap.of("NAME", "name", "REGIONKEY", "regionkey")))))) + .withAlias("row_num", new RowNumberSymbolMatcher())))))))); + + assertPlan( + "SELECT name FROM nation ORDER BY regionkey OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", + any( + strictProject( + ImmutableMap.of("name", new ExpressionMatcher("name")), + any( + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + any( + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + topN( + 7, + ImmutableList.of(sort("regionkey", ASCENDING, LAST)), + TopNNode.Step.FINAL, + anyTree( + tableScan("nation", ImmutableMap.of("NAME", "name", "REGIONKEY", "regionkey"))))) + .withAlias("row_num", new RowNumberSymbolMatcher())))))))); + + assertPlan( + "SELECT name FROM nation OFFSET 2 ROWS FETCH NEXT 5 ROWS ONLY", + any( + strictProject( + ImmutableMap.of("name", new ExpressionMatcher("name")), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + limit( + 7, + false, + any( + tableScan("nation", ImmutableMap.of("NAME", "name"))))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/OffsetMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/OffsetMatcher.java new file mode 100644 index 000000000000..f0641b039338 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/OffsetMatcher.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.assertions; + +import io.prestosql.Session; +import io.prestosql.cost.StatsProvider; +import io.prestosql.metadata.Metadata; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.PlanNode; + +import static com.google.common.base.Preconditions.checkState; + +public class OffsetMatcher + implements Matcher +{ + private final long rowCount; + + public OffsetMatcher(long rowCount) + { + this.rowCount = rowCount; + } + + @Override + public boolean shapeMatches(PlanNode node) + { + if (!(node instanceof OffsetNode)) { + return false; + } + + OffsetNode offsetNode = (OffsetNode) node; + return offsetNode.getCount() == rowCount; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node)); + return MatchResult.match(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java index 2c6479d2f2bb..03193a0f8133 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java @@ -42,6 +42,7 @@ import io.prestosql.sql.planner.plan.LateralJoinNode; import io.prestosql.sql.planner.plan.LimitNode; import io.prestosql.sql.planner.plan.MarkDistinctNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; @@ -525,6 +526,11 @@ public static PlanMatchPattern values(List aliases) return values(aliases, Optional.empty()); } + public static PlanMatchPattern offset(long rowCount, PlanMatchPattern source) + { + return node(OffsetNode.class, source).with(new OffsetMatcher(rowCount)); + } + public static PlanMatchPattern limit(long limit, PlanMatchPattern source) { return limit(limit, false, source); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/RowNumberSymbolMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/RowNumberSymbolMatcher.java new file mode 100644 index 000000000000..35d1fe28fd4b --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/RowNumberSymbolMatcher.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.assertions; + +import io.prestosql.Session; +import io.prestosql.metadata.Metadata; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.RowNumberNode; + +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class RowNumberSymbolMatcher + implements RvalueMatcher +{ + @Override + public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + if (!(node instanceof RowNumberNode)) { + return Optional.empty(); + } + + RowNumberNode rowNumberNode = (RowNumberNode) node; + + return Optional.of(rowNumberNode.getRowNumberSymbol()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .toString(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverOther.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverOther.java new file mode 100644 index 000000000000..27c1049a5da3 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverOther.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; + +public class TestImplementOffsetOverOther + extends BaseRuleTest +{ + @Test + public void testReplaceOffsetOverOther() + { + tester().assertThat(new ImplementOffsetOverOther()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.offset( + 2, + p.values(a, b)); + }) + .matches( + strictProject( + ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + values("a", "b")) + .withAlias("row_num", new RowNumberSymbolMatcher())))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectSort.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectSort.java new file mode 100644 index 000000000000..9fbd324e81c7 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectSort.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.plan.Assignments; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestImplementOffsetOverProjectSort + extends BaseRuleTest +{ + @Test + public void testReplaceOffsetOverProjectSort() + { + tester().assertThat(new ImplementOffsetOverProjectSort()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.offset( + 2, + p.project( + Assignments.identity(b), + p.sort( + ImmutableList.of(a), + p.values(a, b)))); + }) + .matches( + strictProject( + ImmutableMap.of("b", new ExpressionMatcher("b")), + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + sort( + ImmutableList.of(sort("a", ASCENDING, FIRST)), + values("a", "b"))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectTopN.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectTopN.java new file mode 100644 index 000000000000..0e3e5ce6558f --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverProjectTopN.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.plan.Assignments; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestImplementOffsetOverProjectTopN + extends BaseRuleTest +{ + @Test + public void testReplaceOffsetOverProjectTopN() + { + tester().assertThat(new ImplementOffsetOverProjectTopN()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.offset( + 2, + p.project( + Assignments.identity(b), + p.topN( + 5, + ImmutableList.of(a), + p.values(a, b)))); + }) + .matches( + strictProject( + ImmutableMap.of("b", new ExpressionMatcher("b")), + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + topN( + 5, + ImmutableList.of(sort("a", ASCENDING, FIRST)), + values("a", "b"))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverSort.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverSort.java new file mode 100644 index 000000000000..e2a1c261ec8b --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverSort.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestImplementOffsetOverSort + extends BaseRuleTest +{ + @Test + public void testReplaceOffsetOverSort() + { + tester().assertThat(new ImplementOffsetOverSort()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.offset( + 2, + p.sort( + ImmutableList.of(a), + p.values(a, b))); + }) + .matches( + strictProject( + ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + sort( + ImmutableList.of(sort("a", ASCENDING, FIRST)), + values("a", "b"))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverTopN.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverTopN.java new file mode 100644 index 000000000000..399eec970860 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestImplementOffsetOverTopN.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.RowNumberSymbolMatcher; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestImplementOffsetOverTopN + extends BaseRuleTest +{ + @Test + public void testReplaceOffsetOverTopN() + { + tester().assertThat(new ImplementOffsetOverTopN()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.offset( + 2, + p.topN( + 5, + ImmutableList.of(a), + p.values(a, b))); + }) + .matches( + strictProject( + ImmutableMap.of("a", new ExpressionMatcher("a"), "b", new ExpressionMatcher("b")), + sort( + ImmutableList.of(sort("row_num", ASCENDING, FIRST)), + filter( + "(row_num > BIGINT '2')", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + topN( + 5, + ImmutableList.of(sort("a", ASCENDING, FIRST)), + values("a", "b"))) + .withAlias("row_num", new RowNumberSymbolMatcher()))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughOffset.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughOffset.java new file mode 100644 index 000000000000..1f1c3c0acafa --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughOffset.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.offset; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushLimitThroughOffset + extends BaseRuleTest +{ + @Test + public void testPushdownLimitThroughOffset() + { + tester().assertThat(new PushLimitThroughOffset()) + .on(p -> p.limit( + 2, + p.offset(5, p.values()))) + .matches( + offset( + 5, + limit(7, values()))); + } + + @Test + public void doNotPushdownWhenRowCountOverflowsLong() + { + tester().assertThat(new PushLimitThroughOffset()) + .on(p -> { + return p.limit( + Long.MAX_VALUE, + p.offset(Long.MAX_VALUE, p.values())); + }) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushOffsetThroughProject.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushOffsetThroughProject.java new file mode 100644 index 000000000000..b586be020d4d --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushOffsetThroughProject.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.plan.Assignments; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.offset; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; + +public class TestPushOffsetThroughProject + extends BaseRuleTest +{ + @Test + public void testPushdownOffsetNonIdentityProjection() + { + tester().assertThat(new PushOffsetThroughProject()) + .on(p -> { + Symbol a = p.symbol("a"); + return p.offset( + 5, + p.project( + Assignments.of(a, TRUE_LITERAL), + p.values())); + }) + .matches( + strictProject( + ImmutableMap.of("b", expression("true")), + offset(5, values()))); + } + + @Test + public void testDoNotPushdownOffsetThroughIdentityProjection() + { + tester().assertThat(new PushOffsetThroughProject()) + .on(p -> { + Symbol a = p.symbol("a"); + return p.offset( + 5, + p.project( + Assignments.of(a, a.toSymbolReference()), + p.values(a))); + }).doesNotFire(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 3c68e03103b2..531a6bd37254 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -57,6 +57,7 @@ import io.prestosql.sql.planner.plan.LateralJoinNode; import io.prestosql.sql.planner.plan.LimitNode; import io.prestosql.sql.planner.plan.MarkDistinctNode; +import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.OutputNode; import io.prestosql.sql.planner.plan.PlanFragmentId; import io.prestosql.sql.planner.plan.PlanNode; @@ -209,6 +210,11 @@ public SortNode sort(List orderBy, PlanNode source) Maps.toMap(orderBy, Functions.constant(SortOrder.ASC_NULLS_FIRST)))); } + public OffsetNode offset(long rowCount, PlanNode source) + { + return new OffsetNode(idAllocator.getNextId(), source, rowCount); + } + public LimitNode limit(long limit, PlanNode source) { return new LimitNode(idAllocator.getNextId(), source, limit, false); diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java index 2445fd242a06..f3e9468fd5f2 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java @@ -842,6 +842,49 @@ public void testLimitAll() assertQuery("SELECT custkey, totalprice FROM orders LIMIT ALL", "SELECT custkey, totalprice FROM orders"); } + @Test + public void testOffset() + { + String values = "(VALUES ('A', 3), ('D', 2), ('C', 1), ('B', 4)) AS t(x, y)"; + + MaterializedResult actual = computeActual("SELECT x FROM " + values + " OFFSET 2 ROWS"); + MaterializedResult all = computeExpected("SELECT x FROM " + values, actual.getTypes()); + + assertEquals(actual.getMaterializedRows().size(), 2); + assertNotEquals(actual.getMaterializedRows().get(0), actual.getMaterializedRows().get(1)); + assertContains(all, actual); + } + + @Test + public void testOffsetWithFetch() + { + String values = "(VALUES ('A', 3), ('D', 2), ('C', 1), ('B', 4)) AS t(x, y)"; + + MaterializedResult actual = computeActual("SELECT x FROM " + values + " OFFSET 2 ROWS FETCH NEXT ROW ONLY"); + MaterializedResult all = computeExpected("SELECT x FROM " + values, actual.getTypes()); + + assertEquals(actual.getMaterializedRows().size(), 1); + assertContains(all, actual); + } + + @Test + public void testOffsetWithOrderBy() + { + String values = "(VALUES ('A', 3), ('D', 2), ('C', 1), ('B', 4)) AS t(x, y)"; + + assertQuery("SELECT x FROM " + values + " ORDER BY y OFFSET 2 ROWS", "VALUES 'A', 'B'"); + assertQuery("SELECT x FROM " + values + " ORDER BY y OFFSET 2 ROWS FETCH NEXT 1 ROW ONLY", "VALUES 'A'"); + } + + @Test + public void testOffsetEmptyResult() + { + assertQueryReturnsEmptyResult("SELECT name FROM nation OFFSET 100 ROWS"); + assertQueryReturnsEmptyResult("SELECT name FROM nation ORDER BY regionkey OFFSET 100 ROWS"); + assertQueryReturnsEmptyResult("SELECT name FROM nation OFFSET 100 ROWS LIMIT 20"); + assertQueryReturnsEmptyResult("SELECT name FROM nation ORDER BY regionkey OFFSET 100 ROWS LIMIT 20"); + } + @Test public void testRepeatedAggregations() {