Skip to content

Commit

Permalink
[Fix](ShortCircuit) fix prepared statement with partial arguments pre…
Browse files Browse the repository at this point in the history
…pared

We should record the placehold id map to both real Expr and the slot of conjuncts.Otherwise the info is lost
  • Loading branch information
eldenmoon committed Dec 12, 2024
1 parent d4adf92 commit b3422a0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ public class StatementContext implements Closeable {
private final IdGenerator<PlaceholderId> placeHolderIdGenerator = PlaceholderId.createGenerator();
// relation id to placeholders for prepared statement, ordered by placeholder id
private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new TreeMap<>();
// map placeholder id to comparison slot, which will used to replace conjuncts directly
private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new TreeMap<>();

// collect all hash join conditions to compute node connectivity in join graph
private final List<Expression> joinFilters = new ArrayList<>();
Expand Down Expand Up @@ -399,6 +401,10 @@ public Map<PlaceholderId, Expression> getIdToPlaceholderRealExpr() {
return idToPlaceholderRealExpr;
}

public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
return idToComparisonSlot;
}

public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.ArrayType;
Expand Down Expand Up @@ -583,10 +584,26 @@ public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteCon
return visit(realExpr, context);
}

// Register placeholder id to related slot in comparison predicate.
// Used to replace expression in ShortCircuit plan
private void registerPlaceholdIdToSlot(ComparisonPredicate cp,
ExpressionRewriteContext context, Expression left, Expression right) {
// Used to replace expression in ShortCircuit plan
if (cp.right() instanceof Placeholder && left instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.right()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) left);
} else if (cp.left() instanceof Placeholder && right instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.left()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) right);
}
}

@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);
Expression right = cp.right().accept(this, context);
// Used to replace expression in ShortCircuit plan
registerPlaceholdIdToSlot(cp, context, left, right);
cp = (ComparisonPredicate) cp.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(cp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.KeyTuple;
Expand Down Expand Up @@ -59,12 +61,12 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

public class PointQueryExecutor implements CoordInterface {
private static final Logger LOG = LogManager.getLogger(PointQueryExecutor.class);
Expand Down Expand Up @@ -142,33 +144,45 @@ public static void directExecuteShortCircuitQuery(StmtExecutor executor,
Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
ShortCircuitQueryContext shortCircuitQueryContext = preparedStmtCtx.shortCircuitQueryContext.get();
// update conjuncts
List<Expr> conjunctVals = statementContext.getIdToPlaceholderRealExpr().values().stream().map(
expression -> (
(Literal) expression).toLegacyLiteral())
.collect(Collectors.toList());
if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount()) {
Map<String, Expr> colNameToConjunct = Maps.newHashMap();
for (Entry<PlaceholderId, SlotReference> entry : statementContext.getIdToComparisonSlot().entrySet()) {
String colName = entry.getValue().getColumn().get().getName();
Expr conjunctVal = ((Literal) statementContext.getIdToPlaceholderRealExpr()
.get(entry.getKey())).toLegacyLiteral();
colNameToConjunct.put(colName, conjunctVal);
}
if (colNameToConjunct.size() != preparedStmtCtx.command.placeholderCount()) {
throw new AnalysisException("Mismatched conjuncts values size with prepared"
+ "statement parameters size, expected "
+ preparedStmtCtx.command.placeholderCount()
+ ", but meet " + conjunctVals.size());
+ ", but meet " + colNameToConjunct.size());
}
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, conjunctVals);
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct);
// short circuit plan and execution
executor.executeAndSendResult(false, false,
shortCircuitQueryContext.analzyedQuery, executor.getContext()
.getMysqlChannel(), null, null);
}

private static void updateScanNodeConjuncts(OlapScanNode scanNode, List<Expr> conjunctVals) {
for (int i = 0; i < conjunctVals.size(); ++i) {
BinaryPredicate binaryPredicate = (BinaryPredicate) scanNode.getConjuncts().get(i);
private static void updateScanNodeConjuncts(OlapScanNode scanNode,
Map<String, Expr> colNameToConjunct) {
for (Expr conjunct : scanNode.getConjuncts()) {
BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
SlotRef slot = null;
int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
binaryPredicate.setChild(0, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChild(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
binaryPredicate.setChild(1, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChild(0);
updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in " + binaryPredicate.toSqlImpl());
}
// not a placeholder to replace
if (!colNameToConjunct.containsKey(slot.getColumnName())) {
continue;
}
binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName()));
}
}

Expand Down

0 comments on commit b3422a0

Please sign in to comment.