Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix](ShortCircuit) fix prepared statement with partial arguments prepared #45371

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ public class StatementContext implements Closeable {
private final IdGenerator<PlaceholderId> placeHolderIdGenerator = PlaceholderId.createGenerator();
// relation id to placeholders for prepared statement, ordered by placeholder id
private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new TreeMap<>();
// map placeholder id to comparison slot, which will used to replace conjuncts directly
private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new TreeMap<>();

// collect all hash join conditions to compute node connectivity in join graph
private final List<Expression> joinFilters = new ArrayList<>();
Expand Down Expand Up @@ -399,6 +401,10 @@ public Map<PlaceholderId, Expression> getIdToPlaceholderRealExpr() {
return idToPlaceholderRealExpr;
}

public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
return idToComparisonSlot;
}

public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.common.DdlException;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.SqlCacheContext;
import org.apache.doris.nereids.StatementContext;
Expand Down Expand Up @@ -77,6 +78,7 @@
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.ArrayType;
Expand Down Expand Up @@ -583,10 +585,29 @@ public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteCon
return visit(realExpr, context);
}

// Register prepared statement placeholder id to related slot in comparison predicate.
// Used to replace expression in ShortCircuit plan
private void registerPlaceholderIdToSlot(ComparisonPredicate cp,
ExpressionRewriteContext context, Expression left, Expression right) {
if (ConnectContext.get() != null
&& ConnectContext.get().getCommand() == MysqlCommand.COM_STMT_EXECUTE) {
// Used to replace expression in ShortCircuit plan
if (cp.right() instanceof Placeholder && left instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.right()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) left);
} else if (cp.left() instanceof Placeholder && right instanceof SlotReference) {
PlaceholderId id = ((Placeholder) cp.left()).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, (SlotReference) right);
}
}
}

@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);
Expression right = cp.right().accept(this, context);
// Used to replace expression in ShortCircuit plan
registerPlaceholderIdToSlot(cp, context, left, right);
cp = (ComparisonPredicate) cp.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(cp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.KeyTuple;
Expand Down Expand Up @@ -59,12 +61,12 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

public class PointQueryExecutor implements CoordInterface {
private static final Logger LOG = LogManager.getLogger(PointQueryExecutor.class);
Expand Down Expand Up @@ -142,33 +144,45 @@ public static void directExecuteShortCircuitQuery(StmtExecutor executor,
Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
ShortCircuitQueryContext shortCircuitQueryContext = preparedStmtCtx.shortCircuitQueryContext.get();
// update conjuncts
List<Expr> conjunctVals = statementContext.getIdToPlaceholderRealExpr().values().stream().map(
expression -> (
(Literal) expression).toLegacyLiteral())
.collect(Collectors.toList());
if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount()) {
Map<String, Expr> colNameToConjunct = Maps.newHashMap();
for (Entry<PlaceholderId, SlotReference> entry : statementContext.getIdToComparisonSlot().entrySet()) {
String colName = entry.getValue().getColumn().get().getName();
Expr conjunctVal = ((Literal) statementContext.getIdToPlaceholderRealExpr()
.get(entry.getKey())).toLegacyLiteral();
colNameToConjunct.put(colName, conjunctVal);
}
if (colNameToConjunct.size() != preparedStmtCtx.command.placeholderCount()) {
throw new AnalysisException("Mismatched conjuncts values size with prepared"
+ "statement parameters size, expected "
+ preparedStmtCtx.command.placeholderCount()
+ ", but meet " + conjunctVals.size());
+ ", but meet " + colNameToConjunct.size());
}
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, conjunctVals);
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct);
// short circuit plan and execution
executor.executeAndSendResult(false, false,
shortCircuitQueryContext.analzyedQuery, executor.getContext()
.getMysqlChannel(), null, null);
}

private static void updateScanNodeConjuncts(OlapScanNode scanNode, List<Expr> conjunctVals) {
for (int i = 0; i < conjunctVals.size(); ++i) {
BinaryPredicate binaryPredicate = (BinaryPredicate) scanNode.getConjuncts().get(i);
private static void updateScanNodeConjuncts(OlapScanNode scanNode,
Map<String, Expr> colNameToConjunct) {
for (Expr conjunct : scanNode.getConjuncts()) {
BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
SlotRef slot = null;
int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
binaryPredicate.setChild(0, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
binaryPredicate.setChild(1, conjunctVals.get(i));
slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in " + binaryPredicate.toSqlImpl());
}
// not a placeholder to replace
if (!colNameToConjunct.containsKey(slot.getColumnName())) {
continue;
}
binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName()));
}
}

Expand Down
30 changes: 30 additions & 0 deletions regression-test/data/point_query_p0/test_point_query.out
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,33 @@
-- !sql --
-10 20 aabc update val

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

-- !point_select --
user_guid feature sk feature_value 2021-01-01T00:00

92 changes: 73 additions & 19 deletions regression-test/suites/point_query_p0/test_point_query.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,30 @@ suite("test_point_query", "nonConcurrent") {
logger.info("update config: code=" + code + ", out=" + out + ", err=" + err)
}
}
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
// Parse url
String jdbcUrl = context.config.jdbcUrl
String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
def sql_port
if (urlWithoutSchema.indexOf("/") >= 0) {
// e.g: jdbc:mysql://locahost:8080/?a=b
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/"))
} else {
// e.g: jdbc:mysql://locahost:8080
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
}
// set server side prepared statement url
def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true"
try {
set_be_config.call("disable_storage_row_cache", "false")
// nereids do not support point query now
sql "set global enable_fallback_to_original_planner = false"
sql """set global enable_nereids_planner=true"""
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
def tableName = realDb + ".tbl_point_query"
sql "CREATE DATABASE IF NOT EXISTS ${realDb}"

// Parse url
String jdbcUrl = context.config.jdbcUrl
String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
def sql_port
if (urlWithoutSchema.indexOf("/") >= 0) {
// e.g: jdbc:mysql://locahost:8080/?a=b
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, urlWithoutSchema.indexOf("/"))
} else {
// e.g: jdbc:mysql://locahost:8080
sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
}
// set server side prepared statement url
def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb + "?&useServerPrepStmts=true"

def generateString = {len ->
def str = ""
for (int i = 0; i < len; i++) {
Expand Down Expand Up @@ -330,4 +328,60 @@ suite("test_point_query", "nonConcurrent") {
qt_sql "select * from table_3821461 where col1 = 10 and col2 = 20 and loc3 = 'aabc';"
sql "update table_3821461 set value = 'update value' where col1 = -10 or col1 = 20;"
qt_sql """select * from table_3821461 where col1 = -10 and col2 = 20 and loc3 = 'aabc'"""

sql "DROP TABLE IF EXISTS test_partial_prepared_statement"
sql """
CREATE TABLE `test_partial_prepared_statement` (
`user_guid` varchar(64) NOT NULL,
`feature` varchar(256) NOT NULL,
`sk` varchar(256) NOT NULL,
`feature_value` text NULL,
`data_time` datetime NOT NULL
) ENGINE=OLAP
UNIQUE KEY(`user_guid`, `feature`, `sk`)
DISTRIBUTED BY HASH(`user_guid`) BUCKETS 32
PROPERTIES (
"enable_unique_key_merge_on_write" = "true",
"light_schema_change" = "true",
"function_column.sequence_col" = "data_time",
"store_row_column" = "true",
"replication_num" = "1",
"row_store_page_size" = "16384"
);
"""
sql "insert into test_partial_prepared_statement values ('user_guid', 'feature', 'sk','feature_value', '2021-01-01 00:00:00')"
def result2 = connect(user, password, prepare_url) {
def partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = 'sk' and user_guid = 'user_guid' and feature = ? "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enable_nereids_planner is useless, could be removed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get

assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where user_guid = ? and feature = 'feature' and sk = ?"
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
partial_prepared_stmt.setString(2, "sk")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = 'feature' "
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where ? = user_guid and sk = 'sk' and feature = ? "
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "user_guid")
partial_prepared_stmt.setString(2, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt

partial_prepared_stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from regression_test_point_query_p0.test_partial_prepared_statement where sk = ? and feature = ? and 'user_guid' = user_guid"
assertEquals(partial_prepared_stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
partial_prepared_stmt.setString(1, "sk")
partial_prepared_stmt.setString(2, "feature")
qe_point_select partial_prepared_stmt
qe_point_select partial_prepared_stmt
}
}
Loading