diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java index 18dfbad5e0ad5e..1e0e448dd69f45 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java @@ -27,16 +27,6 @@ public interface Cost { double getValue(); - /** - * This is for calculating the cost in simplifier - */ - static Cost withRowCount(double rowCount) { - if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) { - return new CostV2(0, rowCount, 0); - } - return new CostV1(rowCount); - } - /** * return zero cost */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java index bf1cc425999f7c..fb00bacc2877a7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java @@ -45,13 +45,6 @@ public CostV1(double cpuCost, double memoryCost, double networkCost) { + costWeight.networkWeight * networkCost; } - public CostV1(double cost) { - this.cost = cost; - this.cpuCost = 0; - this.networkCost = 0; - this.memoryCost = 0; - } - public static CostV1 infinite() { return INFINITE; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index c9948c21105a64..aa92d0fe0059c6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -18,32 +18,28 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.PlanContext; -import org.apache.doris.nereids.cost.Cost; -import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; import org.apache.doris.nereids.stats.JoinEstimation; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collection; +import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Optional; import java.util.PriorityQueue; import java.util.Set; -import java.util.Stack; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -70,10 +66,9 @@ public class GraphSimplifier { // because it's just used for simulating join. In fact, the graph simplifier // just generate the partial order of join operator. private final HashMap cacheStats = new HashMap<>(); - private final HashMap cacheCost = new HashMap<>(); - - private final Stack appliedSteps = new Stack<>(); - private final Stack unAppliedSteps = new Stack<>(); + private final HashMap cacheCost = new HashMap<>(); + private final Deque appliedSteps = new ArrayDeque<>(); + private final Deque unAppliedSteps = new ArrayDeque<>(); private final Set validEdges; @@ -91,7 +86,7 @@ public GraphSimplifier(HyperGraph graph) { } for (Node node : graph.getNodes()) { cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics()); - cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount())); + cacheCost.put(node.getNodeMap(), node.getRowCount()); } validEdges = graph.getEdges().stream() .filter(e -> { @@ -116,6 +111,13 @@ public GraphSimplifier(HyperGraph graph) { initFirstStep(); } + private boolean isOverlap(Edge edge1, Edge edge2) { + return (LongBitmap.isOverlap(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes()) + && LongBitmap.isOverlap(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes())) + || (LongBitmap.isOverlap(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes()) + && LongBitmap.isOverlap(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes())); + } + private void initFirstStep() { extractJoinDependencies(); for (int i = 0; i < edgeSize; i += 1) { @@ -138,8 +140,10 @@ public boolean isTotalOrder() { tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset); tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes(), superset); tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes(), superset); - if (!circleDetector.checkCircleWithEdge(i, j) && !circleDetector.checkCircleWithEdge(j, i) - && !edge2.isSub(edge1) && !edge1.isSub(edge2) && !superset.isEmpty()) { + if (edge2.isSub(edge1) || edge1.isSub(edge2) || superset.isEmpty() || isOverlap(edge1, edge2)) { + continue; + } + if (!(circleDetector.checkCircleWithEdge(i, j) || circleDetector.checkCircleWithEdge(j, i))) { return false; } } @@ -211,7 +215,7 @@ public boolean applySimplificationStep() { appliedSteps.push(bestStep); Preconditions.checkArgument( cacheStats.containsKey(bestStep.newLeft) && cacheStats.containsKey(bestStep.newRight), - String.format("%s - %s", bestStep.newLeft, bestStep.newRight)); + "<%s - %s> must has been stats derived", bestStep.newLeft, bestStep.newRight); graph.modifyEdge(bestStep.afterIndex, bestStep.newLeft, bestStep.newRight); if (needProcessNeighbor) { processNeighbors(bestStep.afterIndex, 0, edgeSize); @@ -220,7 +224,8 @@ public boolean applySimplificationStep() { } private boolean unApplySimplificationStep() { - Preconditions.checkArgument(appliedSteps.size() > 0); + Preconditions.checkArgument(!appliedSteps.isEmpty(), + "try to unapply a simplification step but there is no step applied"); SimplificationStep bestStep = appliedSteps.pop(); unAppliedSteps.push(bestStep); graph.modifyEdge(bestStep.afterIndex, bestStep.oldLeft, bestStep.oldRight); @@ -350,8 +355,8 @@ private Optional makeSimplificationStep(int edgeIndex1, int || !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) { return Optional.empty(); } - Pair edge1Before2; - Pair edge2Before1; + Edge edge1Before2; + Edge edge2Before1; List superBitset = new ArrayList<>(); if (tryGetSuperset(left1, left2, superBitset)) { // (common Join1 right1) Join2 right2 @@ -377,108 +382,145 @@ private Optional makeSimplificationStep(int edgeIndex1, int return Optional.empty(); } + if (edge1Before2 == null || edge2Before1 == null) { + return Optional.empty(); + } + // edge1 is not the neighborhood of edge2 SimplificationStep simplificationStep = orderJoin(edge1Before2, edge2Before1, edgeIndex1, edgeIndex2); return Optional.of(simplificationStep); } - Pair threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { + private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) { + if (graph.getEdges().size() > 64 * 63 / 8) { + // If there are too many edges, it is advisable to return the "edge" directly + // to avoid lengthy enumeration time. + return edge; + } + BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes); + List hashConditions = validEdgesMap.stream() + .mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + List otherConditions = validEdgesMap.stream() + .mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + LogicalJoin join = + edge.getJoin().withJoinConjuncts(hashConditions, otherConditions); + + Edge newEdge = new Edge( + join, + -1, edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes()); + newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes()); + newEdge.setRightRequiredNodes(edge.getRightRequiredNodes()); + newEdge.addLeftNode(leftNodes); + newEdge.addRightNode(rightNodes); + return newEdge; + } + + private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) { + // The bitmap may differ from the edge's reference slots. + // Taking into account the order: edge1<{1} - {2}> edge2<{1,3} - {4}>. + // Actually, we are considering the sequence {1,3} - {2} - {4} + long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap); + if (cacheStats.containsKey(bitmap)) { + return; + } + // Note the edge in graphSimplifier contains all tree nodes + Statistics joinStats = JoinEstimation + .estimate(cacheStats.get(leftBitmap), + cacheStats.get(rightBitmap), edge.getJoin()); + cacheStats.put(bitmap, joinStats); + } + + private double calCost(Edge edge, long leftBitmap, long rightBitmap) { + long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap); + Preconditions.checkArgument(cacheStats.containsKey(leftBitmap) && cacheStats.containsKey(rightBitmap) + && cacheStats.containsKey(bitmap), + "graph simplifier meet an edge %s that have not been derived stats", edge); + LogicalJoin join = edge.getJoin(); + Statistics leftStats = cacheStats.get(leftBitmap); + Statistics rightStats = cacheStats.get(rightBitmap); + double cost; + if (JoinUtils.shouldNestedLoopJoin(join)) { + cost = cacheCost.get(leftBitmap) + cacheCost.get(rightBitmap) + + rightStats.getRowCount() + 1 / leftStats.getRowCount(); + } else { + cost = cacheCost.get(leftBitmap) + cacheCost.get(rightBitmap) + + (rightStats.getRowCount() + 1 / leftStats.getRowCount()) * 1.2; + } + + if (!cacheCost.containsKey(bitmap) || cacheCost.get(bitmap) > cost) { + cacheCost.put(bitmap, cost); + } + return cost; + } + + private @Nullable Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { // (plan1 edge1 plan2) edge2 plan3 - // The join may have redundant table, e.g., t1,t2 join t3 join t2,t4 - // Therefore, the cost is not accurate + // if the left and right is overlapping, just return null. Preconditions.checkArgument( cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3)); - Statistics leftStats = JoinEstimation.estimate(cacheStats.get(bitmap1), cacheStats.get(bitmap2), - edge1.getJoin()); - Statistics joinStats = JoinEstimation.estimate(leftStats, cacheStats.get(bitmap3), edge2.getJoin()); - Edge edge = new Edge( - edge2.getJoin(), -1, edge2.getLeftChildEdges(), edge2.getRightChildEdges(), edge2.getSubTreeNodes()); + + // construct new Edge long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2); - // To avoid overlapping the left and the right, the newLeft is calculated, Note the - // newLeft is not totally include the bitset1 and bitset2, we use circle detector to trace the dependency - newLeft = LongBitmap.andNot(newLeft, bitmap3); - edge.addLeftNodes(newLeft); - edge.addRightNode(edge2.getRightExtendedNodes()); - cacheStats.put(newLeft, leftStats); - cacheCost.put(newLeft, calCost(edge2, leftStats, cacheStats.get(bitmap1), cacheStats.get(bitmap2))); - return Pair.of(joinStats, edge); + if (LongBitmap.isOverlap(newLeft, bitmap3)) { + return null; + } + Edge newEdge = constructEdge(newLeft, edge2, bitmap3); + + deriveStats(edge1, bitmap1, bitmap2); + deriveStats(newEdge, newLeft, bitmap3); + + calCost(edge1, bitmap1, bitmap2); + + return newEdge; } - Pair threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { + private @Nullable Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { Preconditions.checkArgument( cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3)); // plan1 edge1 (plan2 edge2 plan3) - Statistics rightStats = JoinEstimation.estimate(cacheStats.get(bitmap2), cacheStats.get(bitmap3), - edge2.getJoin()); - Statistics joinStats = JoinEstimation.estimate(cacheStats.get(bitmap1), rightStats, edge1.getJoin()); - Edge edge = new Edge( - edge1.getJoin(), -1, edge1.getLeftChildEdges(), edge1.getRightChildEdges(), edge1.getSubTreeNodes()); - long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3); - newRight = LongBitmap.andNot(newRight, bitmap1); - edge.addLeftNode(edge1.getLeftExtendedNodes()); - edge.addRightNode(newRight); - cacheStats.put(newRight, rightStats); - cacheCost.put(newRight, calCost(edge2, rightStats, cacheStats.get(bitmap2), cacheStats.get(bitmap3))); - return Pair.of(joinStats, edge); - } - - private Edge processMissedEdges(int edgeIndex1, int edgeIndex2, Edge edge) { - List edges = Lists.newArrayList(edge); - edges.addAll(graph.getEdges().stream() - .filter(e -> e.getIndex() != edgeIndex1 && e.getIndex() != edgeIndex2 - && LongBitmap.isSubset(e.getReferenceNodes(), edge.getReferenceNodes()) - && !LongBitmap.isSubset(e.getReferenceNodes(), edge.getLeftExtendedNodes()) - && !LongBitmap.isSubset(e.getReferenceNodes(), edge.getRightExtendedNodes())) - .collect(Collectors.toList())); - if (edges.size() > 1) { - List hashConjuncts = new ArrayList<>(); - List otherConjuncts = new ArrayList<>(); - JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); - LogicalJoin oldJoin = edge.getJoin(); - LogicalJoin newJoin = new LogicalJoin<>(joinType, hashConjuncts, - otherConjuncts, oldJoin.getHint(), oldJoin.left(), oldJoin.right()); - Edge newEdge = Edge.createTempEdge(newJoin); - newEdge.setLeftExtendedNodes(edge.getLeftExtendedNodes()); - newEdge.setRightExtendedNodes(edge.getRightExtendedNodes()); - return newEdge; - } else { - return edge; + if (LongBitmap.isOverlap(bitmap1, newRight)) { + return null; } + Edge newEdge = constructEdge(bitmap1, edge1, newRight); + + deriveStats(edge2, bitmap2, bitmap3); + deriveStats(newEdge, bitmap1, newRight); + + calCost(edge1, bitmap2, bitmap3); + return newEdge; } - private SimplificationStep orderJoin(Pair edge1Before2, - Pair edge2Before1, int edgeIndex1, int edgeIndex2) { - // TODO: Consider miss edges when construct join. - // considering - // a - // / \ - // b - c - // when constructing edge_ab before edge_bc. edge_ac should be added on top join - Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first, - cacheStats.get(edge1Before2.second.getLeftExtendedNodes()), - cacheStats.get(edge1Before2.second.getRightExtendedNodes())); - Cost cost2Before1 = calCost(edge2Before1.second, edge2Before1.first, - cacheStats.get(edge2Before1.second.getLeftExtendedNodes()), - cacheStats.get(edge2Before1.second.getRightExtendedNodes())); + private SimplificationStep orderJoin(Edge edge1Before2, + Edge edge2Before1, int edgeIndex1, int edgeIndex2) { + double cost1Before2 = calCost(edge1Before2, + edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes()); + double cost2Before1 = calCost(edge2Before1, + edge2Before1.getLeftExtendedNodes(), edge2Before1.getRightExtendedNodes()); double benefit = Double.MAX_VALUE; SimplificationStep step; // Choose the plan with smaller cost and make the simplification step to replace the old edge by it. - if (cost1Before2.getValue() < cost2Before1.getValue()) { - if (cost1Before2.getValue() != 0) { - benefit = cost2Before1.getValue() / cost1Before2.getValue(); + if (cost1Before2 < cost2Before1) { + if (cost1Before2 != 0) { + benefit = cost2Before1 / cost1Before2; } // choose edge1Before2 - step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeftExtendedNodes(), - edge1Before2.second.getRightExtendedNodes(), graph.getEdge(edgeIndex2).getLeftExtendedNodes(), + step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, + edge1Before2.getLeftExtendedNodes(), + edge1Before2.getRightExtendedNodes(), + graph.getEdge(edgeIndex2).getLeftExtendedNodes(), graph.getEdge(edgeIndex2).getRightExtendedNodes()); } else { - if (cost2Before1.getValue() != 0) { - benefit = cost1Before2.getValue() / cost2Before1.getValue(); + if (cost2Before1 != 0) { + benefit = cost1Before2 / cost2Before1; } // choose edge2Before1 - step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeftExtendedNodes(), - edge2Before1.second.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(), + step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.getLeftExtendedNodes(), + edge2Before1.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(), graph.getEdge(edgeIndex1).getRightExtendedNodes()); } return step; @@ -495,41 +537,6 @@ private boolean tryGetSuperset(long bitmap1, long bitmap2, List superset) return false; } - private Cost calCost(Edge edge, Statistics stats, - Statistics leftStats, Statistics rightStats) { - LogicalJoin join = edge.getJoin(); - PlanContext planContext = new PlanContext(stats, ImmutableList.of(leftStats, rightStats)); - Cost cost; - if (JoinUtils.shouldNestedLoopJoin(join)) { - PhysicalNestedLoopJoin nestedLoopJoin = new PhysicalNestedLoopJoin<>( - join.getJoinType(), - join.getHashJoinConjuncts(), - join.getOtherJoinConjuncts(), - join.getMarkJoinSlotReference(), - join.getLogicalProperties(), - join.left(), - join.right()); - cost = CostCalculator.calculateCost(nestedLoopJoin, planContext); - cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0); - cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1); - } else { - PhysicalHashJoin hashJoin = new PhysicalHashJoin<>( - join.getJoinType(), - join.getHashJoinConjuncts(), - join.getOtherJoinConjuncts(), - join.getHint(), - join.getMarkJoinSlotReference(), - join.getLogicalProperties(), - join.left(), - join.right()); - cost = CostCalculator.calculateCost(hashJoin, planContext); - cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0); - cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1); - } - - return cost; - } - /** * Put join dependencies into circle detector. */ @@ -594,7 +601,7 @@ static class BestSimplification implements Comparable { @Override public int compareTo(GraphSimplifier.BestSimplification o) { Preconditions.checkArgument(step.isPresent()); - return Double.compare(getBenefit(), o.getBenefit()); + return Double.compare(o.getBenefit(), getBenefit()); } public double getBenefit() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index ccb1aff31cf53b..23ca17825b5637 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -50,6 +50,8 @@ public class HyperGraph { private final List nodes = new ArrayList<>(); private final HashSet nodeSet = new HashSet<>(); private final HashMap slotToNodeMap = new HashMap<>(); + // record all edges that can be placed on the subgraph + private final Map treeEdgesCache = new HashMap<>(); // Record the complex project expression for some subgraph // e.g. project (a + b) @@ -268,6 +270,30 @@ private Pair calculateEnds(long allNodes, Pair leftEdg return Pair.of(left, right); } + public BitSet getEdgesInOperator(long left, long right) { + BitSet operatorEdgesMap = new BitSet(); + operatorEdgesMap.or(getEdgesInTree(LongBitmap.or(left, right))); + operatorEdgesMap.andNot(getEdgesInTree(left)); + operatorEdgesMap.andNot(getEdgesInTree(right)); + return operatorEdgesMap; + } + + /** + * Returns all edges in the tree + */ + public BitSet getEdgesInTree(long treeNodesMap) { + if (!treeEdgesCache.containsKey(treeNodesMap)) { + BitSet edgesMap = new BitSet(); + for (Edge edge : edges) { + if (LongBitmap.isSubset(edge.getReferenceNodes(), treeNodesMap)) { + edgesMap.set(edge.getIndex()); + } + } + treeEdgesCache.put(treeNodesMap, edgesMap); + } + return treeEdgesCache.get(treeNodesMap); + } + private long calNodeMap(Set slots) { Preconditions.checkArgument(slots.size() != 0); long bitmap = LongBitmap.newBitmap(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java index e2849c6d80df90..79859ccf7aa738 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.statistics.Statistics; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; import org.apache.hadoop.util.Lists; import org.junit.jupiter.api.Assertions; @@ -36,6 +38,7 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.List; class GraphSimplifierTest { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); @@ -75,20 +78,32 @@ void testStarQuery() { // | // t2 HyperGraph hyperGraph = new HyperGraphBuilder() - .init(10, 30, 20, 40, 50) + .init(10, 20, 30, 40, 50) .addEdge(JoinType.INNER_JOIN, 0, 1) .addEdge(JoinType.INNER_JOIN, 0, 2) .addEdge(JoinType.INNER_JOIN, 0, 3) .addEdge(JoinType.INNER_JOIN, 0, 4) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - while (graphSimplifier.applySimplificationStep()) { + List> steps = ImmutableList.>builder() + .add(Pair.of(17L, 2L)) // 04 - 1 + .add(Pair.of(17L, 4L)) // 04 - 2 + .add(Pair.of(17L, 8L)) // 04 - 3 + .add(Pair.of(25L, 2L)) // 034 - 1 + .add(Pair.of(25L, 4L)) // 034 - 2 + .add(Pair.of(29L, 2L)) // 0234 - 1 + .build(); // 0-4-3-2-1 : big left deep tree + for (Pair step : steps) { + if (!graphSimplifier.applySimplificationStep()) { + break; + } + System.out.println(graphSimplifier.getLastAppliedSteps()); + Assertions.assertEquals(step, graphSimplifier.getLastAppliedSteps()); } Counter counter = new Counter(); SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph); subgraphEnumerator.enumerate(); for (int count : counter.getAllCount().values()) { - System.out.println(count); Assertions.assertTrue(count < 10); } Assertions.assertTrue(graphSimplifier.isTotalOrder()); @@ -182,24 +197,6 @@ void testHugeStar() { Assertions.assertTrue(graphSimplifier.isTotalOrder()); } - @Test - void testTime() { - int tableNum = 20; - int edgeNum = 40; - double totalTime = 0; - int times = 1; - for (int i = 0; i < times; i++) { - HyperGraph hyperGraph = new HyperGraphBuilder().randomBuildWith(tableNum, edgeNum); - double now = System.currentTimeMillis(); - GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - while (graphSimplifier.applySimplificationStep()) { - } - totalTime += System.currentTimeMillis() - now; - } - System.out.printf("Simplify graph with %d nodes %d edges cost %f ms%n", tableNum, edgeNum, - totalTime / times); - } - @Test void testComplexQuery() { HyperGraph hyperGraph = new HyperGraphBuilder() @@ -235,7 +232,6 @@ void testRandomQuery() { Counter counter = new Counter(); SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph); subgraphEnumerator.enumerate(); - Assertions.assertTrue(graphSimplifier.isTotalOrder()); } } @@ -261,7 +257,7 @@ void benchGraphSimplifier() { int edgeNum = 64 * 63 / 2; int limit = 1000; - int times = 1; + int times = 4; double totalTime = 0; for (int i = 0; i < times; i++) { totalTime += benchGraphSimplifier(tableNum, edgeNum, limit);