Skip to content

Commit

Permalink
4
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Oct 30, 2023
1 parent efe0695 commit 267c526
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,25 @@ public DummyMysqlChannel() {
}

public void setSequenceId(int sequenceId) {
throwIfArrowFlightSql();
this.sequenceId = sequenceId;
}

@Override
public String getRemoteIp() {
throwIfArrowFlightSql();
return "";
}

@Override
public String getRemoteHostPortString() {
throwIfArrowFlightSql();
return "";
}

@Override
public void close() {
throwIfArrowFlightSql();
}

@Override
Expand All @@ -56,26 +60,32 @@ protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException {

@Override
public ByteBuffer fetchOnePacket() throws IOException {
throwIfArrowFlightSql();
return ByteBuffer.allocate(0);
}

@Override
public void flush() throws IOException {
throwIfArrowFlightSql();
}

@Override
public void sendOnePacket(ByteBuffer packet) throws IOException {
throwIfArrowFlightSql();
}

@Override
public void sendAndFlush(ByteBuffer packet) throws IOException {
throwIfArrowFlightSql();
}

@Override
public void reset() {
throwIfArrowFlightSql();
}

public MysqlSerializer getSerializer() {
throwIfArrowFlightSql();
return serializer;
}
}
37 changes: 37 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Arrays;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
Expand Down Expand Up @@ -78,16 +79,19 @@ public class MysqlChannel {

// mysql flag CLIENT_DEPRECATE_EOF
private boolean clientDeprecatedEOF;
protected boolean isFlightSql = false;

protected MysqlChannel() {
// For DummyMysqlChannel
}

public void setClientDeprecatedEOF() {
throwIfArrowFlightSql();
clientDeprecatedEOF = true;
}

public boolean clientDeprecatedEOF() {
throwIfArrowFlightSql();
return clientDeprecatedEOF;
}

Expand Down Expand Up @@ -116,6 +120,7 @@ public MysqlChannel(StreamConnection connection) {
}

public void initSslBuffer() {
throwIfArrowFlightSql();
// allocate buffer when needed.
this.remainingBuffer = ByteBuffer.allocate(16 * 1024);
this.remainingBuffer.flip();
Expand All @@ -124,20 +129,28 @@ public void initSslBuffer() {
}

public void setSequenceId(int sequenceId) {
throwIfArrowFlightSql();
this.sequenceId = sequenceId;
}

public void setIsFlightSql(boolean isFlightSql) {
this.isFlightSql = isFlightSql;
}

public String getRemoteIp() {
throwIfArrowFlightSql();
return remoteIp;
}

public void setSslEngine(SSLEngine sslEngine) {
throwIfArrowFlightSql();
this.sslEngine = sslEngine;
decryptAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize() * 2);
encryptNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize() * 2);
}

public void setSslMode(boolean sslMode) {
throwIfArrowFlightSql();
isSslMode = sslMode;
if (isSslMode) {
// channel in ssl mode means handshake phase has finished.
Expand All @@ -146,6 +159,7 @@ public void setSslMode(boolean sslMode) {
}

public void setSslHandshaking(boolean sslHandshaking) {
throwIfArrowFlightSql();
isSslHandshaking = sslHandshaking;
}

Expand Down Expand Up @@ -173,6 +187,7 @@ private void accSequenceId() {

// Close channel
public void close() {
throwIfArrowFlightSql();
try {
conn.close();
} catch (IOException e) {
Expand Down Expand Up @@ -237,6 +252,7 @@ protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLExcept
// NOTE: all of the following code is assumed that the channel is in block mode.
// if in handshaking mode we return a packet with header otherwise without header.
public ByteBuffer fetchOnePacket() throws IOException {
throwIfArrowFlightSql();
int readLen;
ByteBuffer result = defaultBuffer;
result.clear();
Expand Down Expand Up @@ -390,6 +406,7 @@ protected ByteBuffer encryptData(ByteBuffer dstBuf) throws SSLException {
}

public void flush() throws IOException {
throwIfArrowFlightSql();
if (null == sendBuffer || sendBuffer.position() == 0) {
// Nothing to send
return;
Expand Down Expand Up @@ -437,6 +454,7 @@ private void writeBuffer(ByteBuffer buffer, boolean isSsl) throws IOException {
}

public void sendOnePacket(ByteBuffer packet) throws IOException {
throwIfArrowFlightSql();
// handshake in packet with header and has encrypted, need to send in ssl format
// ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format
int bufLen;
Expand Down Expand Up @@ -464,40 +482,48 @@ public void sendOnePacket(ByteBuffer packet) throws IOException {
}

public void sendAndFlush(ByteBuffer packet) throws IOException {
throwIfArrowFlightSql();
sendOnePacket(packet);
flush();
}

// Call this function before send query before
public void reset() {
throwIfArrowFlightSql();
isSend = false;
if (null != sendBuffer) {
sendBuffer.clear();
}
}

public boolean isSend() {
throwIfArrowFlightSql();
return isSend;
}

public String getRemoteHostPortString() {
throwIfArrowFlightSql();
return remoteHostPortString;
}

public void startAcceptQuery(ConnectContext connectContext, ConnectProcessor connectProcessor) {
throwIfArrowFlightSql();
conn.getSourceChannel().setReadListener(new ReadListener(connectContext, connectProcessor));
conn.getSourceChannel().resumeReads();
}

public void suspendAcceptQuery() {
throwIfArrowFlightSql();
conn.getSourceChannel().suspendReads();
}

public void resumeAcceptQuery() {
throwIfArrowFlightSql();
conn.getSourceChannel().resumeReads();
}

public void stopAcceptQuery() throws IOException {
throwIfArrowFlightSql();
conn.getSourceChannel().shutdownReads();
}

Expand Down Expand Up @@ -552,4 +578,15 @@ private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) {
}
}

protected void throwIfArrowFlightSql() {
if (isFlightSql) {
try {
throw new RuntimeException("Arrow flight sql unexpected use of MysqlChannel");
} catch (Exception e) {
LOG.warn(e.getMessage() + Arrays.toString(e.getStackTrace()));
// throw e;
}
}
}

}
15 changes: 12 additions & 3 deletions fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import org.apache.doris.nereids.stats.StatsErrorEstimator;
import org.apache.doris.plugin.AuditEvent.AuditEventBuilder;
import org.apache.doris.resource.Tag;
import org.apache.doris.service.arrowflight.MysqlChannelToFlightSql;
import org.apache.doris.service.arrowflight.FlightSqlChannel;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Histogram;
import org.apache.doris.system.Backend;
Expand All @@ -51,6 +51,7 @@
import org.apache.doris.transaction.TransactionEntry;
import org.apache.doris.transaction.TransactionStatus;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
Expand Down Expand Up @@ -104,6 +105,7 @@ public enum ConnectType {
protected volatile String peerIdentity;
// mysql net
protected volatile MysqlChannel mysqlChannel;
protected volatile FlightSqlChannel flightSqlChannel;
// state
protected volatile QueryState state;
protected volatile long returnRows;
Expand Down Expand Up @@ -351,7 +353,9 @@ public ConnectContext(String peerIdentity) {
returnRows = 0;
isKilled = false;
sessionVariable = VariableMgr.newSessionVariable();
mysqlChannel = new MysqlChannelToFlightSql();
mysqlChannel = new DummyMysqlChannel();
mysqlChannel.setIsFlightSql(true);
flightSqlChannel = new FlightSqlChannel();
command = MysqlCommand.COM_SLEEP;
if (Config.use_fuzzy_session_variable) {
sessionVariable.initFuzzyModeVariables();
Expand Down Expand Up @@ -603,11 +607,16 @@ public MysqlChannel getMysqlChannel() {
return mysqlChannel;
}

public FlightSqlChannel getFlightSqlChannel() {
Preconditions.checkState(connectType.equals(ConnectType.ARROW_FLIGHT_SQL));
return flightSqlChannel;
}

public String getClientIP() {
if (connectType.equals(ConnectType.MYSQL)) {
return mysqlChannel.getRemoteHostPortString();
} else {
return "0.0.0.0"; // TODO
return flightSqlChannel.getRemoteHostPortString();
}
}

Expand Down
49 changes: 29 additions & 20 deletions fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ public boolean isAnalyzeStmt() {
* isValuesOrConstantSelect: when this interface return true, original string is truncated at 1024
*
* @return parsed and analyzed statement for Stale planner.
* an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids.
* an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids.
*/
public StatementBase getParsedStmt() {
return parsedStmt;
Expand Down Expand Up @@ -597,7 +597,7 @@ private void parseByNereids() {
}
if (statements.size() <= originStmt.idx) {
throw new ParseException("Nereids parse failed. Parser get " + statements.size() + " statements,"
+ " but we need at least " + originStmt.idx + " statements.");
+ " but we need at least " + originStmt.idx + " statements.");
}
parsedStmt = statements.get(originStmt.idx);
}
Expand Down Expand Up @@ -636,7 +636,7 @@ private void handleQueryWithRetry(TUniqueId queryId) throws Exception {
try {
for (int i = 0; i < retryTime; i++) {
try {
//reset query id for each retry
// reset query id for each retry
if (i > 0) {
UUID uuid = UUID.randomUUID();
TUniqueId newQueryId = new TUniqueId(uuid.getMostSignificantBits(),
Expand Down Expand Up @@ -1363,7 +1363,7 @@ private void handleCacheStmt(CacheAnalyzer cacheAnalyzer, MysqlChannel channel)
// Process a select statement.
private void handleQueryStmt() throws Exception {
LOG.info("Handling query {} with query id {}",
originStmt.originStmt, DebugUtil.printId(context.queryId));
originStmt.originStmt, DebugUtil.printId(context.queryId));
// Every time set no send flag and clean all data in buffer
context.getMysqlChannel().reset();
Queriable queryStmt = (Queriable) parsedStmt;
Expand Down Expand Up @@ -2102,10 +2102,10 @@ private void handleSwitchStmt() throws AnalysisException {
private void handlePrepareStmt() throws Exception {
// register prepareStmt
LOG.debug("add prepared statement {}, isBinaryProtocol {}",
prepareStmt.getName(), prepareStmt.isBinaryProtocol());
prepareStmt.getName(), prepareStmt.isBinaryProtocol());
context.addPreparedStmt(prepareStmt.getName(),
new PrepareStmtContext(prepareStmt,
context, planner, analyzer, prepareStmt.getName()));
context, planner, analyzer, prepareStmt.getName()));
if (prepareStmt.isBinaryProtocol()) {
sendStmtPrepareOK();
}
Expand Down Expand Up @@ -2223,24 +2223,33 @@ private void sendFields(List<String> colNames, List<Type> types) throws IOExcept
}

public void sendResultSet(ResultSet resultSet) throws IOException {
context.updateReturnRows(resultSet.getResultRows().size());
// Send meta data.
sendMetaData(resultSet.getMetaData());
if (context.getConnectType().equals(ConnectType.MYSQL)) {
context.updateReturnRows(resultSet.getResultRows().size());
// Send meta data.
sendMetaData(resultSet.getMetaData());

// Send result set.
for (List<String> row : resultSet.getResultRows()) {
serializer.reset();
for (String item : row) {
if (item == null || item.equals(FeConstants.null_string)) {
serializer.writeNull();
} else {
serializer.writeLenEncodedString(item);
// Send result set.
for (List<String> row : resultSet.getResultRows()) {
serializer.reset();
for (String item : row) {
if (item == null || item.equals(FeConstants.null_string)) {
serializer.writeNull();
} else {
serializer.writeLenEncodedString(item);
}
}
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}

context.getState().setEof();
context.getState().setEof();
} else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
context.updateReturnRows(resultSet.getResultRows().size());
context.getFlightSqlChannel()
.addResultSet(DebugUtil.printId(context.queryId()), context.getRunningQuery(), resultSet);
context.getState().setEof();
} else {
LOG.error("sendResultSet error connect type");
}
}

// Process show statement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
import org.apache.doris.qe.QueryState;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.qe.VariableMgr;
import org.apache.doris.service.arrowflight.FlightSQLConnectProcessor;
import org.apache.doris.service.arrowflight.FlightSqlConnectProcessor;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ResultRow;
import org.apache.doris.statistics.StatisticsCacheKey;
Expand Down Expand Up @@ -1112,7 +1112,7 @@ public TMasterOpResult forward(TMasterOpRequest params) throws TException {
if (context.getConnectType().equals(ConnectType.MYSQL)) {
processor = new MysqlConnectProcessor(context);
} else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
processor = new FlightSQLConnectProcessor(context);
processor = new FlightSqlConnectProcessor(context);
} else {
LOG.warn("unknown ConnectType: {}", context.getConnectType());
}
Expand Down
Loading

0 comments on commit 267c526

Please sign in to comment.