Skip to content

Commit

Permalink
[Fix](ShortCircuit) fix prepared statement with partial arguments pre…
Browse files Browse the repository at this point in the history
…pared (apache#45371)

We should record the placehold id map to both real Expr and the slot of
conjuncts.Otherwise the info is lost, and lead to the conjuncts updated
in wrong order(`updateScanNodeConjuncts`)
  • Loading branch information
eldenmoon committed Dec 20, 2024
1 parent 17cc76e commit b8ce691
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
Expand Down Expand Up @@ -145,6 +146,8 @@ public enum TableFrom {
private final IdGenerator<PlaceholderId> placeHolderIdGenerator = PlaceholderId.createGenerator();
// relation id to placeholders for prepared statement, ordered by placeholder id
private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new TreeMap<>();
// map placeholder id to comparison slot, which will used to replace conjuncts directly
private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new TreeMap<>();

// collect all hash join conditions to compute node connectivity in join graph
private final List<Expression> joinFilters = new ArrayList<>();
Expand Down Expand Up @@ -448,6 +451,10 @@ public Map<PlaceholderId, Expression> getIdToPlaceholderRealExpr() {
return idToPlaceholderRealExpr;
}

public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
return idToComparisonSlot;
}

public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.common.DdlException;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.SqlCacheContext;
import org.apache.doris.nereids.StatementContext;
Expand Down Expand Up @@ -78,6 +79,7 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.ArrayType;
Expand Down Expand Up @@ -546,10 +548,29 @@ public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteCon
return visit(realExpr, context);
}

// Register prepared statement placeholder id to related slot in comparison predicate.
// Used to replace expression in ShortCircuit plan
private void registerPlaceholderIdToSlot(ComparisonPredicate cp,
ExpressionRewriteContext context, Expression left, Expression right) {
if (ConnectContext.get() != null
&& ConnectContext.get().getCommand() == MysqlCommand.COM_STMT_EXECUTE) {
// Used to replace expression in ShortCircuit plan
if (cp.right() instanceof Placeholder && left instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.right()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) left);
} else if (cp.left() instanceof Placeholder && right instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.left()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) right);
}
}
}

@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);
Expression right = cp.right().accept(this, context);
// Used to replace expression in ShortCircuit plan
registerPlaceholderIdToSlot(cp, context, left, right);
cp = (ComparisonPredicate) cp.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(cp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.KeyTuple;
Expand All @@ -55,11 +57,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

public class PointQueryExecutor implements CoordInterface {
private static final Logger LOG = LogManager.getLogger(PointQueryExecutor.class);
Expand Down Expand Up @@ -113,33 +115,45 @@ public static void directExecuteShortCircuitQuery(StmtExecutor executor,
Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
ShortCircuitQueryContext shortCircuitQueryContext = preparedStmtCtx.shortCircuitQueryContext.get();
// update conjuncts
List<Expr> conjunctVals = statementContext.getIdToPlaceholderRealExpr().values().stream().map(
expression -> (
(Literal) expression).toLegacyLiteral())
.collect(Collectors.toList());
if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount()) {
Map<String, Expr> colNameToConjunct = Maps.newHashMap();
for (Entry<PlaceholderId, SlotReference> entry : statementContext.getIdToComparisonSlot().entrySet()) {
String colName = entry.getValue().getColumn().get().getName();
Expr conjunctVal = ((Literal) statementContext.getIdToPlaceholderRealExpr()
.get(entry.getKey())).toLegacyLiteral();
colNameToConjunct.put(colName, conjunctVal);
}
if (colNameToConjunct.size() != preparedStmtCtx.command.placeholderCount()) {
throw new AnalysisException("Mismatched conjuncts values size with prepared"
+ "statement parameters size, expected "
+ preparedStmtCtx.command.placeholderCount()
+ ", but meet " + conjunctVals.size());
+ ", but meet " + colNameToConjunct.size());
}
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, conjunctVals);
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct);
// short circuit plan and execution
executor.executeAndSendResult(false, false,
shortCircuitQueryContext.analzyedQuery, executor.getContext()
.getMysqlChannel(), null, null);
}

private static void updateScanNodeConjuncts(OlapScanNode scanNode, List<Expr> conjunctVals) {
for (int i = 0; i < conjunctVals.size(); ++i) {
BinaryPredicate binaryPredicate = (BinaryPredicate) scanNode.getConjuncts().get(i);
private static void updateScanNodeConjuncts(OlapScanNode scanNode,
Map<String, Expr> colNameToConjunct) {
for (Expr conjunct : scanNode.getConjuncts()) {
BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
SlotRef slot = null;
int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
binaryPredicate.setChild(0, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
binaryPredicate.setChild(1, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in " + binaryPredicate.toSqlImpl());
}
// not a placeholder to replace
if (!colNameToConjunct.containsKey(slot.getColumnName())) {
continue;
}
binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName()));
}
}

Expand Down
31 changes: 31 additions & 0 deletions regression-test/data/point_query_p0/test_point_query.out
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,34 @@
-- !sql --
-10 20 aabc update val

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00
>>>>>>> 91c475e0f4 ([Fix](ShortCircuit) fix prepared statement with partial arguments prepared (#45371))

94 changes: 74 additions & 20 deletions regression-test/suites/point_query_p0/test_point_query.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,30 @@ suite("test_point_query", "nonConcurrent") {
logger.info("update config: code=" + code + ", out=" + out + ", err=" + err)
}
}
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
// Parse url
String jdbcUrl = context.config.jdbcUrl
String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
def sql_port
if (urlWithoutSchema.indexOf("/") >= 0) {
// e.g: jdbc:mysql://locahost:8080/?a=b
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/"))
} else {
// e.g: jdbc:mysql://locahost:8080
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
}
// set server side prepared statement url
def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true"
try {
set_be_config.call("disable_storage_row_cache", "false")
// nereids do not support point query now
sql "set global enable_fallback_to_original_planner = false"
sql """set global enable_nereids_planner=true"""
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
def tableName = realDb + ".tbl_point_query"
sql "CREATE DATABASE IF NOT EXISTS ${realDb}"

// Parse url
String jdbcUrl = context.config.jdbcUrl
String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
def sql_port
if (urlWithoutSchema.indexOf("/") >= 0) {
// e.g: jdbc:mysql://locahost:8080/?a=b
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/"))
} else {
// e.g: jdbc:mysql://locahost:8080
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
}
// set server side prepared statement url
def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true"

def generateString = {len ->
def str = ""
for (int i = 0; i < len; i++) {
Expand Down Expand Up @@ -331,4 +329,60 @@ suite("test_point_query", "nonConcurrent") {
sql "update table_3821461 set value = 'update value' where col1 = -10 or col1 = 20;"
qt_sql """select * from table_3821461 where col1 = -10 and col2 = 20 and loc3 = 'aabc'"""
qt_sql """select * from table_3821461 where col2 = 20 and loc3 = 'aabc' and col1 = -10 """
}

sql "DROP TABLE IF EXISTS test_partial_prepared_statement"
sql """
CREATE TABLE `test_partial_prepared_statement` (
`user_guid` varchar(64) NOT NULL,
`feature` varchar(256) NOT NULL,
`sk` varchar(256) NOT NULL,
`feature_value` text NULL,
`data_time` datetime NOT NULL
) ENGINE=OLAP
UNIQUE KEY(`user_guid`, `feature`, `sk`)
DISTRIBUTED BY HASH(`user_guid`) BUCKETS 32
PROPERTIES (
"enable_unique_key_merge_on_write" = "true",
"light_schema_change" = "true",
"function_column.sequence_col" = "data_time",
"store_row_column" = "true",
"replication_num" = "1",
"row_store_page_size" = "16384"
);
"""
sql "insert into test_partial_prepared_statement values ('user_guid', 'feature', 'sk','feature_value', '2021-01-01 00:00:00')"
def result2 = connect(user, password, prepare_url) {
def partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = 'sk' and user_guid = 'user_guid' and feature = ? "
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where user_guid = ? and feature = 'feature' and sk = ?"
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
partial_prepared_stmt.setString(2, "sk")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = 'feature' "
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = ? "
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
partial_prepared_stmt.setString(2, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = ? and feature = ? and 'user_guid' = user_guid"
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "sk")
partial_prepared_stmt.setString(2, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt
}
}

0 comments on commit b8ce691

Please sign in to comment.