Skip to content

Commit

Permalink
Implement Offset
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi authored and martint committed May 10, 2019
1 parent 70e1204 commit 4d6325b
Show file tree
Hide file tree
Showing 32 changed files with 1,532 additions and 1 deletion.
13 changes: 13 additions & 0 deletions presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.prestosql.sql.tree.LambdaArgumentDeclaration;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.Offset;
import io.prestosql.sql.tree.OrderBy;
import io.prestosql.sql.tree.QuantifiedComparisonExpression;
import io.prestosql.sql.tree.Query;
Expand Down Expand Up @@ -98,6 +99,7 @@ public class Analysis
private final Map<NodeRef<Node>, List<Expression>> outputExpressions = new LinkedHashMap<>();
private final Map<NodeRef<QuerySpecification>, List<FunctionCall>> windowFunctions = new LinkedHashMap<>();
private final Map<NodeRef<OrderBy>, List<FunctionCall>> orderByWindowFunctions = new LinkedHashMap<>();
private final Map<NodeRef<Offset>, Long> offset = new LinkedHashMap<>();
private final Map<NodeRef<Node>, OptionalLong> limit = new LinkedHashMap<>();

private final Map<NodeRef<Join>, Expression> joins = new LinkedHashMap<>();
Expand Down Expand Up @@ -318,6 +320,17 @@ public List<Expression> getOrderByExpressions(Node node)
return orderByExpressions.get(NodeRef.of(node));
}

public void setOffset(Offset node, long rowCount)
{
offset.put(NodeRef.of(node), rowCount);
}

public long getOffset(Offset node)
{
checkState(offset.containsKey(NodeRef.of(node)), "missing OFFSET value for node %s", node);
return offset.get(NodeRef.of(node));
}

public void setLimit(Node node, OptionalLong rowCount)
{
limit.put(NodeRef.of(node), rowCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public enum SemanticErrorCode

TOO_MANY_GROUPING_SETS,

INVALID_OFFSET_ROW_COUNT,
INVALID_FETCH_FIRST_ROW_COUNT,
INVALID_LIMIT_ROW_COUNT,
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
import static io.prestosql.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_FETCH_FIRST_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_LIMIT_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS;
import static io.prestosql.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME;
Expand Down Expand Up @@ -2135,7 +2136,17 @@ private List<Expression> analyzeOrderBy(Node node, List<SortItem> sortItems, Sco

private void analyzeOffset(Offset node)
{
throw new SemanticException(NOT_SUPPORTED, node, "OFFSET not yet implemented");
long rowCount;
try {
rowCount = Long.parseLong(node.getRowCount());
}
catch (NumberFormatException e) {
throw new SemanticException(INVALID_OFFSET_ROW_COUNT, node, "Invalid OFFSET row count: %s", node.getRowCount());
}
if (rowCount < 0) {
throw new SemanticException(INVALID_OFFSET_ROW_COUNT, node, "OFFSET row count must be greater or equal to 0 (actual value: %s)", rowCount);
}
analysis.setOffset(node, rowCount);
}

private void analyzeLimit(Node node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
import io.prestosql.sql.planner.iterative.rule.GatherAndMergeWindows;
import io.prestosql.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter;
import io.prestosql.sql.planner.iterative.rule.ImplementFilteredAggregations;
import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverOther;
import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverProjectSort;
import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverProjectTopN;
import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverSort;
import io.prestosql.sql.planner.iterative.rule.ImplementOffsetOverTopN;
import io.prestosql.sql.planner.iterative.rule.InlineProjections;
import io.prestosql.sql.planner.iterative.rule.MergeFilters;
import io.prestosql.sql.planner.iterative.rule.MergeLimitOverProjectWithSort;
Expand Down Expand Up @@ -78,9 +83,11 @@
import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushLimitIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct;
import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOffset;
import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushLimitThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushLimitThroughSemiJoin;
import io.prestosql.sql.planner.iterative.rule.PushOffsetThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughExchange;
import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughJoin;
import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan;
Expand Down Expand Up @@ -289,6 +296,8 @@ public PlanOptimizers(
new EvaluateZeroLimit(),
new EvaluateZeroTopN(),
new EvaluateZeroSample(),
new PushOffsetThroughProject(),
new PushLimitThroughOffset(),
new PushLimitThroughProject(),
new MergeLimits(),
new MergeLimitWithSort(),
Expand All @@ -309,6 +318,16 @@ public PlanOptimizers(
new PruneOrderByInAggregation(metadata.getFunctionRegistry()),
new RewriteSpatialPartitioningAggregation(metadata)))
.build()),
new IterativeOptimizer(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(
new ImplementOffsetOverTopN(),
new ImplementOffsetOverProjectTopN(),
new ImplementOffsetOverSort(),
new ImplementOffsetOverProjectSort(),
new ImplementOffsetOverOther())),
simplifyOptimizer,
new UnaliasSymbolReferences(),
new IterativeOptimizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SortNode;
Expand All @@ -55,6 +56,7 @@
import io.prestosql.sql.tree.LambdaExpression;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.Offset;
import io.prestosql.sql.tree.OrderBy;
import io.prestosql.sql.tree.Query;
import io.prestosql.sql.tree.QuerySpecification;
Expand Down Expand Up @@ -133,6 +135,7 @@ public RelationPlan plan(Query query)

builder = sort(builder, query.getOrderBy(), analysis.getOrderByExpressions(query));
builder = project(builder, analysis.getOutputExpressions(query));
builder = offset(builder, query.getOffset());
builder = limit(builder, query.getLimit());

return new RelationPlan(
Expand Down Expand Up @@ -183,6 +186,7 @@ public RelationPlan plan(QuerySpecification node)
builder = distinct(builder, node);
builder = sort(builder, node.getOrderBy(), analysis.getOrderByExpressions(node));
builder = project(builder, outputs);
builder = offset(builder, node.getOffset());
builder = limit(builder, node.getLimit());

return new RelationPlan(
Expand Down Expand Up @@ -881,6 +885,19 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional<OrderBy> orderBy, List<Ex
new OrderingScheme(orderBySymbols.build(), orderings)));
}

private PlanBuilder offset(PlanBuilder subPlan, Optional<Offset> offset)
{
if (!offset.isPresent()) {
return subPlan;
}

return subPlan.withNewRoot(
new OffsetNode(
idAllocator.getNextId(),
subPlan.getRoot(),
analysis.getOffset(offset.get())));
}

private PlanBuilder limit(PlanBuilder subPlan, Optional<Node> limit)
{
if (limit.isPresent() && analysis.getLimit(limit.get()).isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.GenericLiteral;

import java.util.Optional;

import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.sql.planner.plan.Patterns.offset;
import static io.prestosql.sql.planner.plan.Patterns.source;

/**
* Transforms:
* <pre>
* - Offset (row count = x)
* - Source (other than Sort, TopN)
* </pre>
* Into:
* <pre>
* - Project (prune rowNumber symbol)
* - Filter (rowNumber > x)
* - RowNumber
* - Source
* </pre>
*/
public class ImplementOffsetOverOther
implements Rule<OffsetNode>
{
private static final Pattern<OffsetNode> PATTERN = offset()
.with(source().matching(node -> !(node instanceof TopNNode || node instanceof SortNode)));

@Override
public Pattern<OffsetNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(OffsetNode parent, Captures captures, Context context)
{
Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT);

RowNumberNode rowNumberNode = new RowNumberNode(
context.getIdAllocator().getNextId(),
parent.getSource(),
ImmutableList.of(),
rowNumberSymbol,
Optional.empty(),
Optional.empty());

FilterNode filterNode = new FilterNode(
context.getIdAllocator().getNextId(),
rowNumberNode,
new ComparisonExpression(
ComparisonExpression.Operator.GREATER_THAN,
rowNumberSymbol.toSymbolReference(),
new GenericLiteral("BIGINT", Long.toString(parent.getCount()))));

ProjectNode projectNode = new ProjectNode(
context.getIdAllocator().getNextId(),
filterNode,
Assignments.identity(parent.getOutputSymbols()));

return Result.ofPlanNode(projectNode);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.GenericLiteral;

import java.util.Optional;

import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.sql.planner.plan.Patterns.offset;
import static io.prestosql.sql.planner.plan.Patterns.project;
import static io.prestosql.sql.planner.plan.Patterns.sort;
import static io.prestosql.sql.planner.plan.Patterns.source;

/**
* Transforms:
* <pre>
* - Offset (row count = x)
* - Project (prunes sort symbols no longer useful)
* - Sort (order by a, b, c)
* </pre>
* Into:
* <pre>
* - Project (prunes rowNumber symbol and sort symbols no longer useful)
* - Sort (order by rowNumber)
* - Filter (rowNumber > x)
* - RowNumber
* - Sort (order by a, b, c)
* </pre>
*/
public class ImplementOffsetOverProjectSort
implements Rule<OffsetNode>
{
private static final Capture<ProjectNode> PROJECT = newCapture();
private static final Capture<SortNode> SORT = newCapture();

private static final Pattern<OffsetNode> PATTERN = offset()
.with(source().matching(
project().capturedAs(PROJECT).matching(ProjectNode::isIdentity)
.with(source().matching(
sort().capturedAs(SORT)))));

@Override
public Pattern<OffsetNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(OffsetNode parent, Captures captures, Context context)
{
ProjectNode project = captures.get(PROJECT);
SortNode sort = captures.get(SORT);

Symbol rowNumberSymbol = context.getSymbolAllocator().newSymbol("row_number", BIGINT);

RowNumberNode rowNumberNode = new RowNumberNode(
context.getIdAllocator().getNextId(),
sort,
ImmutableList.of(),
rowNumberSymbol,
Optional.empty(),
Optional.empty());

FilterNode filterNode = new FilterNode(
context.getIdAllocator().getNextId(),
rowNumberNode,
new ComparisonExpression(
ComparisonExpression.Operator.GREATER_THAN,
rowNumberSymbol.toSymbolReference(),
new GenericLiteral("BIGINT", Long.toString(parent.getCount()))));

SortNode sortNode = new SortNode(
context.getIdAllocator().getNextId(),
filterNode,
new OrderingScheme(ImmutableList.of(rowNumberSymbol), ImmutableMap.of(rowNumberSymbol, SortOrder.ASC_NULLS_FIRST)));

ProjectNode projectNode = (ProjectNode) project.replaceChildren(ImmutableList.of(sortNode));

return Result.ofPlanNode(projectNode);
}
}
Loading

0 comments on commit 4d6325b

Please sign in to comment.