diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 5f1c4ea7272101..c9f3264f8bb978 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Placeholder; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.ObjectId; import org.apache.doris.nereids.trees.plans.PlaceholderId; @@ -145,6 +146,8 @@ public enum TableFrom { private final IdGenerator placeHolderIdGenerator = PlaceholderId.createGenerator(); // relation id to placeholders for prepared statement, ordered by placeholder id private final Map idToPlaceholderRealExpr = new TreeMap<>(); + // map placeholder id to comparison slot, which will used to replace conjuncts directly + private final Map idToComparisonSlot = new TreeMap<>(); // collect all hash join conditions to compute node connectivity in join graph private final List joinFilters = new ArrayList<>(); @@ -448,6 +451,10 @@ public Map getIdToPlaceholderRealExpr() { return idToPlaceholderRealExpr; } + public Map getIdToComparisonSlot() { + return idToComparisonSlot; + } + public Map, Group>>> getCteIdToConsumerGroup() { return cteIdToConsumerGroup; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 5bc586d2d42d55..c06d9c769511b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -24,6 +24,7 @@ import org.apache.doris.common.DdlException; import org.apache.doris.common.Pair; import org.apache.doris.common.util.Util; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.SqlCacheContext; import org.apache.doris.nereids.StatementContext; @@ -78,6 +79,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; +import org.apache.doris.nereids.trees.plans.PlaceholderId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.types.ArrayType; @@ -546,10 +548,29 @@ public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteCon return visit(realExpr, context); } + // Register prepared statement placeholder id to related slot in comparison predicate. + // Used to replace expression in ShortCircuit plan + private void registerPlaceholderIdToSlot(ComparisonPredicate cp, + ExpressionRewriteContext context, Expression left, Expression right) { + if (ConnectContext.get() != null + && ConnectContext.get().getCommand() == MysqlCommand.COM_STMT_EXECUTE) { + // Used to replace expression in ShortCircuit plan + if (cp.right() instanceof Placeholder && left instanceof SlotReference) { + PlaceholderId id = ((Placeholder) cp.right()).getPlaceholderId(); + context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) left); + } else if (cp.left() instanceof Placeholder && right instanceof SlotReference) { + PlaceholderId id = ((Placeholder) cp.left()).getPlaceholderId(); + context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) right); + } + } + } + @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { Expression left = cp.left().accept(this, context); Expression right = cp.right().accept(this, context); + // Used to replace expression in ShortCircuit plan + registerPlaceholderIdToSlot(cp, context, left, right); cp = (ComparisonPredicate) cp.withChildren(left, right); return TypeCoercionUtils.processComparisonPredicate(cp); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java index 72a5a8e66a8a98..471655a3e79562 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java @@ -29,7 +29,9 @@ import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.PlaceholderId; import org.apache.doris.planner.OlapScanNode; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.InternalService.KeyTuple; @@ -55,11 +57,11 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.stream.Collectors; public class PointQueryExecutor implements CoordInterface { private static final Logger LOG = LogManager.getLogger(PointQueryExecutor.class); @@ -113,33 +115,45 @@ public static void directExecuteShortCircuitQuery(StmtExecutor executor, Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext); ShortCircuitQueryContext shortCircuitQueryContext = preparedStmtCtx.shortCircuitQueryContext.get(); // update conjuncts - List conjunctVals = statementContext.getIdToPlaceholderRealExpr().values().stream().map( - expression -> ( - (Literal) expression).toLegacyLiteral()) - .collect(Collectors.toList()); - if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount()) { + Map colNameToConjunct = Maps.newHashMap(); + for (Entry entry : statementContext.getIdToComparisonSlot().entrySet()) { + String colName = entry.getValue().getColumn().get().getName(); + Expr conjunctVal = ((Literal) statementContext.getIdToPlaceholderRealExpr() + .get(entry.getKey())).toLegacyLiteral(); + colNameToConjunct.put(colName, conjunctVal); + } + if (colNameToConjunct.size() != preparedStmtCtx.command.placeholderCount()) { throw new AnalysisException("Mismatched conjuncts values size with prepared" + "statement parameters size, expected " + preparedStmtCtx.command.placeholderCount() - + ", but meet " + conjunctVals.size()); + + ", but meet " + colNameToConjunct.size()); } - updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, conjunctVals); + updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct); // short circuit plan and execution executor.executeAndSendResult(false, false, shortCircuitQueryContext.analzyedQuery, executor.getContext() .getMysqlChannel(), null, null); } - private static void updateScanNodeConjuncts(OlapScanNode scanNode, List conjunctVals) { - for (int i = 0; i < conjunctVals.size(); ++i) { - BinaryPredicate binaryPredicate = (BinaryPredicate) scanNode.getConjuncts().get(i); + private static void updateScanNodeConjuncts(OlapScanNode scanNode, + Map colNameToConjunct) { + for (Expr conjunct : scanNode.getConjuncts()) { + BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct; + SlotRef slot = null; + int updateChildIdx = 0; if (binaryPredicate.getChild(0) instanceof LiteralExpr) { - binaryPredicate.setChild(0, conjunctVals.get(i)); + slot = (SlotRef) binaryPredicate.getChildWithoutCast(1); } else if (binaryPredicate.getChild(1) instanceof LiteralExpr) { - binaryPredicate.setChild(1, conjunctVals.get(i)); + slot = (SlotRef) binaryPredicate.getChildWithoutCast(0); + updateChildIdx = 1; } else { Preconditions.checkState(false, "Should contains literal in " + binaryPredicate.toSqlImpl()); } + // not a placeholder to replace + if (!colNameToConjunct.containsKey(slot.getColumnName())) { + continue; + } + binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName())); } } diff --git a/regression-test/data/point_query_p0/test_point_query.out b/regression-test/data/point_query_p0/test_point_query.out index 3003a098a5b826..7245020673b4a9 100644 --- a/regression-test/data/point_query_p0/test_point_query.out +++ b/regression-test/data/point_query_p0/test_point_query.out @@ -163,3 +163,34 @@ -- !sql -- -10 20 aabc update val +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 + +-- !point_select -- +user_guid feature sk feature_value 2021-01-01T00:00 +>>>>>>> 91c475e0f4 ([Fix](ShortCircuit) fix prepared statement with partial arguments prepared (#45371)) + diff --git a/regression-test/suites/point_query_p0/test_point_query.groovy b/regression-test/suites/point_query_p0/test_point_query.groovy index d9b34803d4fcdb..f552a86ed54d69 100644 --- a/regression-test/suites/point_query_p0/test_point_query.groovy +++ b/regression-test/suites/point_query_p0/test_point_query.groovy @@ -27,32 +27,30 @@ suite("test_point_query", "nonConcurrent") { logger.info("update config: code=" + code + ", out=" + out + ", err=" + err) } } + def user = context.config.jdbcUser + def password = context.config.jdbcPassword + def realDb = "regression_test_serving_p0" + // Parse url + String jdbcUrl = context.config.jdbcUrl + String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3) + def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":")) + def sql_port + if (urlWithoutSchema.indexOf("/") >= 0) { + // e.g: jdbc:mysql://locahost:8080/?a=b + sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/")) + } else { + // e.g: jdbc:mysql://locahost:8080 + sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1) + } + // set server side prepared statement url + def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true" try { set_be_config.call("disable_storage_row_cache", "false") - // nereids do not support point query now sql "set global enable_fallback_to_original_planner = false" sql """set global enable_nereids_planner=true""" - def user = context.config.jdbcUser - def password = context.config.jdbcPassword - def realDb = "regression_test_serving_p0" def tableName = realDb + ".tbl_point_query" sql "CREATE DATABASE IF NOT EXISTS ${realDb}" - // Parse url - String jdbcUrl = context.config.jdbcUrl - String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3) - def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":")) - def sql_port - if (urlWithoutSchema.indexOf("/") >= 0) { - // e.g: jdbc:mysql://locahost:8080/?a=b - sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/")) - } else { - // e.g: jdbc:mysql://locahost:8080 - sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1) - } - // set server side prepared statement url - def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true" - def generateString = {len -> def str = "" for (int i = 0; i < len; i++) { @@ -331,4 +329,60 @@ suite("test_point_query", "nonConcurrent") { sql "update table_3821461 set value = 'update value' where col1 = -10 or col1 = 20;" qt_sql """select * from table_3821461 where col1 = -10 and col2 = 20 and loc3 = 'aabc'""" qt_sql """select * from table_3821461 where col2 = 20 and loc3 = 'aabc' and col1 = -10 """ -} \ No newline at end of file + + sql "DROP TABLE IF EXISTS test_partial_prepared_statement" + sql """ + CREATE TABLE `test_partial_prepared_statement` ( + `user_guid` varchar(64) NOT NULL, + `feature` varchar(256) NOT NULL, + `sk` varchar(256) NOT NULL, + `feature_value` text NULL, + `data_time` datetime NOT NULL + ) ENGINE=OLAP + UNIQUE KEY(`user_guid`, `feature`, `sk`) + DISTRIBUTED BY HASH(`user_guid`) BUCKETS 32 + PROPERTIES ( + "enable_unique_key_merge_on_write" = "true", + "light_schema_change" = "true", + "function_column.sequence_col" = "data_time", + "store_row_column" = "true", + "replication_num" = "1", + "row_store_page_size" = "16384" + ); + """ + sql "insert into test_partial_prepared_statement values ('user_guid', 'feature', 'sk','feature_value', '2021-01-01 00:00:00')" + def result2 = connect(user, password, prepare_url) { + def partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = 'sk' and user_guid = 'user_guid' and feature = ? " + assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement); + partial_prepared_stmt.setString(1, "feature") + qe_point_select partial_prepared_stmt + qe_point_select partial_prepared_stmt + + partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where user_guid = ? and feature = 'feature' and sk = ?" + assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement); + partial_prepared_stmt.setString(1, "user_guid") + partial_prepared_stmt.setString(2, "sk") + qe_point_select partial_prepared_stmt + qe_point_select partial_prepared_stmt + + partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = 'feature' " + assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement); + partial_prepared_stmt.setString(1, "user_guid") + qe_point_select partial_prepared_stmt + qe_point_select partial_prepared_stmt + + partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = ? " + assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement); + partial_prepared_stmt.setString(1, "user_guid") + partial_prepared_stmt.setString(2, "feature") + qe_point_select partial_prepared_stmt + qe_point_select partial_prepared_stmt + + partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = ? and feature = ? and 'user_guid' = user_guid" + assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement); + partial_prepared_stmt.setString(1, "sk") + partial_prepared_stmt.setString(2, "feature") + qe_point_select partial_prepared_stmt + qe_point_select partial_prepared_stmt + } +}