Skip to content

Commit

Permalink
4
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Sep 25, 2023
1 parent 3132e2a commit f15f6b1
Show file tree
Hide file tree
Showing 16 changed files with 380 additions and 279 deletions.
25 changes: 18 additions & 7 deletions fe/fe-common/src/main/java/org/apache/doris/common/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -2190,13 +2190,24 @@ public class Config extends ConfigBase {
})
public static long auto_analyze_job_record_count = 20000;

@ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰",
@ConfField(description = {"Arrow Flight Server中所有用户token的缓存上限,超过后LRU淘汰,默认值为2000",
"The cache limit of all user tokens in Arrow Flight Server. which will be eliminated by"
+ "LRU rules after exceeding the limit."})
public static int arrow_flight_token_cache_size = -1;

@ConfField(description = {"Arrow Flight Server中用户token的存活时间,单位分钟",
"The alive time of the user token in Arrow Flight Server, unit minutes"})
public static int arrow_flight_token_alive_time = -1;
+ "LRU rules after exceeding the limit, the default value is 2000."})
public static int arrow_flight_token_cache_size = 2000;

@ConfField(description = {"Arrow Flight Server中用户token的存活时间,自上次写入后过期时间,单位分钟,默认值为4320,即3天",
"The alive time of the user token in Arrow Flight Server, expire after write, unit minutes,"
+ "the default value is 4320, which is 3 days"})
public static int arrow_flight_token_alive_time = 4320;

@ConfField(description = {"Arrow Flight Server中所有用户session的缓存上限,超过后LRU淘汰,默认值为1000",
"The cache limit of all user sessions in Arrow Flight Server. which will be eliminated by"
+ "LRU rules after exceeding the limit, the default value is 1000."})
public static int arrow_flight_session_cache_size = 1000;

@ConfField(description = {"Arrow Flight Server中用户session的存活时间,自上次访问后过期时间,单位分钟,默认值为120",
"The alive time of the user token in Arrow Flight Server, expire after access, unit minutes,"
+ "the default value is 120"})
public static int arrow_flight_session_alive_time = 120;

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.Util;
import org.apache.doris.service.arrowflight.tokens.TokenManager;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import org.apache.doris.service.arrowflight.sessions.FlightUserSession;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
Expand Down Expand Up @@ -73,11 +74,11 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable
private final Location location;
private final BufferAllocator rootAllocator = new RootAllocator();
private final SqlInfoBuilder sqlInfoBuilder;
private final TokenManager tokenManager;
private final FlightSessionsManager flightSessionsManager;

public DorisFlightSqlProducer(final Location location, TokenManager tokenManager) {
public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) {
this.location = location;
this.tokenManager = tokenManager;
this.flightSessionsManager = flightSessionsManager;
sqlInfoBuilder = new SqlInfoBuilder();
sqlInfoBuilder.withFlightSqlServerName("DorisFE")
.withFlightSqlServerVersion("1.0")
Expand Down Expand Up @@ -107,9 +108,10 @@ public void closePreparedStatement(final ActionClosePreparedStatementRequest req
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
final FlightDescriptor descriptor) {
try {
tokenManager.validateToken(context.peerIdentity());
FlightUserSession flightUserSession = flightSessionsManager.getUserSession(context.peerIdentity());
final String query = request.getQuery();
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query);
final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query,
flightUserSession.getConnectContext());

flightStatementExecutor.executeQuery();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import org.apache.doris.common.Config;
import org.apache.doris.service.arrowflight.auth2.FlightBearerTokenAuthenticator;
import org.apache.doris.service.arrowflight.auth2.FlightCookieMiddleware;
import org.apache.doris.service.arrowflight.tokens.TokenManager;
import org.apache.doris.service.arrowflight.tokens.TokenManagerImpl;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsWithTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManagerImpl;

import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand All @@ -40,23 +40,21 @@ public class DorisFlightSqlService {
private static final Logger LOG = LogManager.getLogger(DorisFlightSqlService.class);
private final FlightServer flightServer;
private volatile boolean running;
private final TokenManager tokenManager;
public static final String FLIGHT_CLIENT_PROPERTIES_MIDDLEWARE = "client-properties-middleware";
public static final FlightServerMiddleware.Key<FlightCookieMiddleware> FLIGHT_CLIENT_PROPERTIES_MIDDLEWARE_KEY
= FlightServerMiddleware.Key.of(FLIGHT_CLIENT_PROPERTIES_MIDDLEWARE);
private final FlightTokenManager flightTokenManager;
private final FlightSessionsManager flightSessionsManager;

public DorisFlightSqlService(int port) {
BufferAllocator allocator = new RootAllocator();
Location location = Location.forGrpcInsecure("0.0.0.0", port);
this.tokenManager = new TokenManagerImpl(Config.arrow_flight_token_cache_size,
this.flightTokenManager = new FlightTokenManagerImpl(Config.arrow_flight_token_cache_size,
Config.arrow_flight_token_alive_time);
this.flightSessionsManager = new FlightSessionsWithTokenManager(flightTokenManager,
Config.arrow_flight_session_cache_size,
Config.arrow_flight_session_alive_time);

DorisFlightSqlProducer producer = new DorisFlightSqlProducer(location, tokenManager);
DorisFlightSqlProducer producer = new DorisFlightSqlProducer(location, flightSessionsManager);
flightServer = FlightServer.builder(allocator, location, producer)
.headerAuthenticator(new FlightBearerTokenAuthenticator(tokenManager)).build();
// .middleware(FLIGHT_CLIENT_PROPERTIES_MIDDLEWARE_KEY,
// new FlightServerCookieMiddleware.Factory())
// .authHandler(new BasicServerAuthHandler(new FlightServerBasicAuthValidator())).build();
.headerAuthenticator(new FlightBearerTokenAuthenticator(flightTokenManager)).build();
}

// start Arrow Flight SQL service, return true if success, otherwise false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ public final class FlightStatementExecutor {
private TNetworkAddress resultInternalServiceAddr;
private ArrayList<Expr> resultOutputExprs;

public FlightStatementExecutor(final String query) {
public FlightStatementExecutor(final String query, ConnectContext connectContext) {
this.query = query;
acConnectContext = buildConnectContext();
this.acConnectContext = new AutoCloseConnectContext(connectContext);
}

public void setQueryId(TUniqueId queryId) {
Expand Down Expand Up @@ -126,21 +126,6 @@ public int hashCode() {
return Objects.hash(this);
}

public static AutoCloseConnectContext buildConnectContext() {
ConnectContext connectContext = new ConnectContext();
SessionVariable sessionVariable = connectContext.getSessionVariable();
sessionVariable.internalSession = true;
sessionVariable.setEnablePipelineEngine(false); // TODO
sessionVariable.setEnablePipelineXEngine(false); // TODO
connectContext.setEnv(Env.getCurrentEnv());
connectContext.setQualifiedUser(UserIdentity.ROOT.getQualifiedUser()); // TODO
connectContext.setCurrentUserIdentity(UserIdentity.ROOT); // TODO
connectContext.setStartTime();
connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER);
connectContext.setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL);
return new AutoCloseConnectContext(connectContext);
}

public void executeQuery() {
try {
UUID uuid = UUID.randomUUID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/ServerCookieMiddleware.java
// and modified by Doris

package org.apache.doris.service.arrowflight.auth2;

Expand All @@ -28,14 +26,14 @@
* Result of Authentication.
*/
@Value.Immutable
public interface DorisAuthResult {
public interface FlightAuthResult {
String getUserName();

UserIdentity getUserIdentity();

String getRemoteIp();

static DorisAuthResult of(String userName, UserIdentity userIdentity, String remoteIp) {
static FlightAuthResult of(String userName, UserIdentity userIdentity, String remoteIp) {
return ImmutableDorisAuthResult.builder()
.userName(userName)
.userIdentity(userIdentity)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AuthenticationException;
import org.apache.doris.service.arrowflight.tokens.TokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
Expand All @@ -30,48 +30,46 @@
import java.util.List;

/**
* A collection of common Dremio Flight server authentication methods.
* A collection of common Flight server authentication methods.
*/
public final class FlightAuthUtils {
private FlightAuthUtils() {
}

/**
* Authenticate against Dremio with the provided credentials.
* Authenticate against with the provided credentials.
*
* @param username Dremio username.
* @param password Dremio password.
* @param username username.
* @param password password.
* @param logger the slf4j logger for logging.
* @throws org.apache.arrow.flight.FlightRuntimeException if unable to authenticate against Dremio
* @throws org.apache.arrow.flight.FlightRuntimeException if unable to authenticate against
* with the provided credentials.
*/
public static DorisAuthResult authenticateCredentials(String username, String password, String remoteIp,
public static FlightAuthResult authenticateCredentials(String username, String password, String remoteIp,
Logger logger) {
try {
List<UserIdentity> currentUserIdentity = Lists.newArrayList();

Env.getCurrentEnv().getAuth().checkPlainPassword(username, remoteIp, password, currentUserIdentity);
Preconditions.checkState(currentUserIdentity.size() == 1);
return DorisAuthResult.of(username, currentUserIdentity.get(0), remoteIp);
return FlightAuthResult.of(username, currentUserIdentity.get(0), remoteIp);
} catch (AuthenticationException e) {
logger.error("Unable to authenticate user {}", username, e);
final String errorMessage = "Unable to authenticate user " + username + ", exception: " + e.getMessage();
throw CallStatus.UNAUTHENTICATED.withCause(e).withDescription(errorMessage).toRuntimeException();
final String errMsg = "Unable to authenticate user " + username + ", exception: " + e.getMessage();
throw CallStatus.UNAUTHENTICATED.withCause(e).withDescription(errMsg).toRuntimeException();
}
}

/**
* Create a new token with the TokenManager and create a new UserSession object associated with
* the authenticated username.
* Creates a new Bearer Token. Returns the bearer token associated with the User.
*
* @param tokenManager the TokenManager.
* @param flightTokenManager the TokenManager.
* @param username the user to create a Flight server session for.
* @param dorisAuthResult tht DorisAuthResult.
* @return the token associated with the UserSession created.
* @param flightAuthResult the FlightAuthResult.
* @return the token associated with the FlightTokenDetails created.
*/
public static String createToken(TokenManager tokenManager, String username, DorisAuthResult dorisAuthResult) {
// TODO: DX-25278: Add ClientAddress information while creating a Token in DremioFlightServerAuthValidator
return tokenManager.createToken(username, dorisAuthResult).token;
public static String createToken(FlightTokenManager flightTokenManager, String username,
FlightAuthResult flightAuthResult) {
return flightTokenManager.createToken(username, flightAuthResult).getToken();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/ServerCookieMiddleware.java
// https://github.com/dremio/dremio-oss/blob/master/services/arrow-flight/src/main/java/com/dremio/service/flight/auth2/DremioBearerTokenAuthenticator.java
// and modified by Doris

package org.apache.doris.service.arrowflight.auth2;

import org.apache.doris.service.arrowflight.tokens.TokenManager;
import org.apache.doris.service.arrowflight.tokens.FlightTokenManager;

import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallStatus;
Expand All @@ -32,27 +32,28 @@
import org.apache.logging.log4j.Logger;

/**
* Dremio's custom implementation of CallHeaderAuthenticator for bearer token authentication.
* This class implements CallHeaderAuthenticator rather than BearerTokenAuthenticator. Dremio
* creates UserSession objects when the bearer token is created and requires access to the CallHeaders
* Doris's custom implementation of CallHeaderAuthenticator for bearer token authentication.
* This class implements CallHeaderAuthenticator rather than BearerTokenAuthenticator. Doris
* creates FlightTokenDetails objects when the bearer token is created and requires access to the CallHeaders
* in getAuthResultWithBearerToken.
*/

public class FlightBearerTokenAuthenticator implements CallHeaderAuthenticator {
private static final Logger LOG = LogManager.getLogger(FlightBearerTokenAuthenticator.class);

private final CallHeaderAuthenticator initialAuthenticator;
private final TokenManager tokenManager;
private final FlightTokenManager flightTokenManager;

public FlightBearerTokenAuthenticator(TokenManager tokenManager) {
this.tokenManager = tokenManager;
this.initialAuthenticator = new BasicCallHeaderAuthenticator(new FlightCredentialValidator(this.tokenManager));
public FlightBearerTokenAuthenticator(FlightTokenManager flightTokenManager) {
this.flightTokenManager = flightTokenManager;
this.initialAuthenticator = new BasicCallHeaderAuthenticator(
new FlightCredentialValidator(this.flightTokenManager));
}

/**
* If no bearer token is provided, the method initiates initial password and username
* authentication. Once authenticated, client properties are retrieved from incoming CallHeaders.
* Then it generates a token and creates a UserSession with the retrieved client properties.
* Then it generates a token and creates a FlightTokenDetails with the retrieved client properties.
* associated with it.
* <p>
* If a bearer token is provided, the method validates the provided token.
Expand Down Expand Up @@ -81,7 +82,7 @@ public AuthResult authenticate(CallHeaders incomingHeaders) {
*/
AuthResult validateBearer(String token) {
try {
tokenManager.validateToken(token);
flightTokenManager.validateToken(token);
return createAuthResultWithBearerToken(token);
} catch (IllegalArgumentException e) {
LOG.error("Bearer token validation failed.", e);
Expand All @@ -93,7 +94,7 @@ AuthResult validateBearer(String token) {
/**
* Helper method to create an AuthResult.
*
* @param token the token to create a UserSession for.
* @param token the token to create a FlightTokenDetails for.
* @return a new AuthResult with functionality to add given bearer token to the outgoing header.
*/
private AuthResult createAuthResultWithBearerToken(String token) {
Expand Down
Loading

0 comments on commit f15f6b1

Please sign in to comment.