Skip to content

Commit

Permalink
mv cost related Pr
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Jul 1, 2024
1 parent 798d9d6 commit 809641f
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ public void setSchema(List<Column> newSchema) throws IOException {
initColumnNameMap();
}

public List<Column> getPrefixKeyColumns() {
List<Column> keys = Lists.newArrayList();
for (Column col : schema) {
if (col.isKey()) {
keys.add(col);
} else {
break;
}
}
return keys;
}

public void setSchemaHash(int newSchemaHash) {
this.schemaHash = newSchemaHash;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@

package org.apache.doris.nereids.cost;

import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
Expand All @@ -52,8 +58,11 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;

import java.util.Collections;
import java.util.List;
import java.util.Set;

class CostModelV1 extends PlanVisitor<Cost, PlanContext> {

Expand Down Expand Up @@ -113,6 +122,57 @@ public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext
return CostV1.ofCpu(context.getSessionVariable(), rows - aggMvBonus);
}

private Set<Column> getColumnForRangePredicate(Set<Expression> expressions) {
Set<Column> columns = Sets.newHashSet();
for (Expression expr : expressions) {
if (expr instanceof ComparisonPredicate) {
ComparisonPredicate compare = (ComparisonPredicate) expr;
boolean hasLiteral = compare.left() instanceof Literal || compare.right() instanceof Literal;
boolean hasSlot = compare.left() instanceof SlotReference || compare.right() instanceof SlotReference;
if (hasSlot && hasLiteral) {
if (compare.left() instanceof SlotReference) {
if (((SlotReference) compare.left()).getColumn().isPresent()) {
columns.add(((SlotReference) compare.left()).getColumn().get());
}
} else {
if (((SlotReference) compare.right()).getColumn().isPresent()) {
columns.add(((SlotReference) compare.right()).getColumn().get());
}
}
}
}
}
return columns;
}

@Override
public Cost visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, PlanContext context) {
double exprCost = expressionTreeCost(filter.getExpressions());
double filterCostFactor = 0.0001;
if (ConnectContext.get() != null) {
filterCostFactor = ConnectContext.get().getSessionVariable().filterCostFactor;
}
int prefixIndexMatched = 0;
if (filter.getGroupExpression().isPresent()) {
OlapScan olapScan = (OlapScan) filter.getGroupExpression().get().getFirstChildPlan(OlapScan.class);
if (olapScan != null) {
// check prefix index
long idxId = olapScan.getSelectedIndexId();
List<Column> keyColumns = olapScan.getTable().getIndexMetaByIndexId(idxId).getPrefixKeyColumns();
Set<Column> predicateColumns = getColumnForRangePredicate(filter.getConjuncts());
for (Column col : keyColumns) {
if (predicateColumns.contains(col)) {
prefixIndexMatched++;
} else {
break;
}
}
}
}
return CostV1.ofCpu(context.getSessionVariable(),
(filter.getConjuncts().size() - prefixIndexMatched + exprCost) * filterCostFactor);
}

@Override
public Cost visitPhysicalDeferMaterializeOlapScan(PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan,
PlanContext context) {
Expand Down Expand Up @@ -141,7 +201,8 @@ public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext

@Override
public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
return CostV1.ofCpu(context.getSessionVariable(), 1);
double exprCost = expressionTreeCost(physicalProject.getProjects());
return CostV1.ofCpu(context.getSessionVariable(), exprCost + 1);
}

@Override
Expand Down Expand Up @@ -252,16 +313,29 @@ public Cost visitPhysicalDistribute(
intputRowCount * childStatistics.dataSizeFactor() * RANDOM_SHUFFLE_TO_HASH_SHUFFLE_FACTOR / beNumber);
}

private double expressionTreeCost(List<? extends Expression> expressions) {
double exprCost = 0.0;
ExpressionCostEvaluator expressionCostEvaluator = new ExpressionCostEvaluator();
for (Expression expr : expressions) {
if (!(expr instanceof SlotReference)) {
exprCost += expr.accept(expressionCostEvaluator, null);
}
}
return exprCost;
}

@Override
public Cost visitPhysicalHashAggregate(
PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
Statistics inputStatistics = context.getChildStatistics(0);
double exprCost = expressionTreeCost(aggregate.getExpressions());
if (aggregate.getAggPhase().isLocal()) {
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount() / beNumber,
return CostV1.of(context.getSessionVariable(),
exprCost / 100 + inputStatistics.getRowCount() / beNumber,
inputStatistics.getRowCount() / beNumber, 0);
} else {
// global
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount(),
return CostV1.of(context.getSessionVariable(), exprCost / 100 + inputStatistics.getRowCount(),
inputStatistics.getRowCount(), 0);
}
}
Expand Down Expand Up @@ -289,7 +363,7 @@ public Cost visitPhysicalHashJoin(

double leftRowCount = probeStats.getRowCount();
double rightRowCount = buildStats.getRowCount();
if (leftRowCount == rightRowCount
if ((long) leftRowCount == (long) rightRowCount
&& physicalHashJoin.getGroupExpression().isPresent()
&& physicalHashJoin.getGroupExpression().get().getOwnerGroup() != null
&& !physicalHashJoin.getGroupExpression().get().getOwnerGroup().isStatsReliable()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.collect.Maps;

import java.util.Map;

/**
* expression cost is calculated by
* 1. non-leaf tree node count: N
* 2. expression which contains input of stringType or complexType(array/json/struct...), add cost
*/
public class ExpressionCostEvaluator extends ExpressionVisitor<Double, Void> {
private static Map<Class, Double> dataTypeCost = Maps.newHashMap();

static {
dataTypeCost.put(DecimalV2Type.class, 1.5);
dataTypeCost.put(DecimalV3Type.class, 1.5);
dataTypeCost.put(StringType.class, 2.0);
dataTypeCost.put(CharType.class, 2.0);
dataTypeCost.put(VarcharType.class, 2.0);
dataTypeCost.put(ArrayType.class, 3.0);
dataTypeCost.put(MapType.class, 3.0);
dataTypeCost.put(StructType.class, 3.0);
}

@Override
public Double visit(Expression expr, Void context) {
double cost = 0.0;
for (Expression child : expr.children()) {
cost += child.accept(this, context);
// the more children, the more computing cost
cost += dataTypeCost.getOrDefault(child.getDataType().getClass(), 0.1);
}
return cost;
}

@Override
public Double visitSlotReference(SlotReference slot, Void context) {
return 0.0;
}

@Override
public Double visitLiteral(Literal literal, Void context) {
return 0.0;
}

@Override
public Double visitAlias(Alias alias, Void context) {
Expression child = alias.child();
if (child instanceof SlotReference) {
return 0.0;
}
return alias.child().accept(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,28 @@ public String toString() {
public ObjectId getId() {
return id;
}

/**
* the first child plan of clazz
* @param clazz the operator type, like join/aggregate
* @return child operator of type clazz, if not found, return null
*/
public Plan getFirstChildPlan(Class clazz) {
for (Group childGroup : children) {
for (GroupExpression logical : childGroup.getLogicalExpressions()) {
if (clazz.isInstance(logical.getPlan())) {
return logical.getPlan();
}
}
}
// for dphyp
for (Group childGroup : children) {
for (GroupExpression physical : childGroup.getPhysicalExpressions()) {
if (clazz.isInstance(physical.getPlan())) {
return physical.getPlan();
}
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,24 @@ private Statistics computeAssertNumRows(AssertNumRowsElement assertNumRowsElemen

private Statistics computeFilter(Filter filter) {
Statistics stats = groupExpression.childStatistics(0);
Plan plan = tryToFindChild(groupExpression);
boolean isOnBaseTable = false;
if (plan != null) {
if (plan instanceof OlapScan) {
isOnBaseTable = true;
} else if (plan instanceof Aggregate) {
Aggregate agg = ((Aggregate<?>) plan);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (plan instanceof LogicalJoin && filter instanceof LogicalFilter
if (groupExpression.getFirstChildPlan(OlapScan.class) != null) {
return new FilterEstimation(true).estimate(filter.getPredicate(), stats);
}
if (groupExpression.getFirstChildPlan(Aggregate.class) != null) {
Aggregate agg = (Aggregate<?>) groupExpression.getFirstChildPlan(Aggregate.class);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (groupExpression.getFirstChildPlan(LogicalJoin.class) != null) {
LogicalJoin plan = (LogicalJoin) groupExpression.getFirstChildPlan(LogicalJoin.class);
if (filter instanceof LogicalFilter
&& filter.getConjuncts().stream().anyMatch(e -> e instanceof IsNull)) {
Statistics isNullStats = computeGeneratedIsNullStats((LogicalJoin) plan, filter);
if (isNullStats != null) {
Expand All @@ -640,8 +640,7 @@ private Statistics computeFilter(Filter filter) {
}
}
}

return new FilterEstimation(isOnBaseTable).estimate(filter.getPredicate(), stats);
return new FilterEstimation(false).estimate(filter.getPredicate(), stats);
}

private Statistics computeGeneratedIsNullStats(LogicalJoin join, Filter filter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,8 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
private boolean enableNewCostModel = false;

@VariableMgr.VarAttr(name = "filter_cost_factor", needForward = true)
public double filterCostFactor = 0.0001;
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
private boolean nereidsStarSchemaSupport = true;

Expand Down

0 comments on commit 809641f

Please sign in to comment.