Skip to content

Commit

Permalink
3
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Sep 22, 2023
1 parent abefa95 commit fc38284
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 100 deletions.
55 changes: 55 additions & 0 deletions be/src/service/arrow_flight/auth_server_middleware.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "service/arrow_flight/auth_server_middleware.h"

#include "service/arrow_flight/call_header_utils.h"

namespace doris {
namespace flight {

void NoOpHeaderAuthServerMiddleware::SendingHeaders(
arrow::flight::AddCallHeaders* outgoing_headers) {
outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerDefaultToken);
}

arrow::Status NoOpHeaderAuthServerMiddlewareFactory::StartCall(
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
std::string username, password;
ParseBasicHeader(context.incoming_headers(), username, password);
*middleware = std::make_shared<NoOpHeaderAuthServerMiddleware>();
return arrow::Status::OK();
}

void NoOpBearerAuthServerMiddleware::SendingHeaders(
arrow::flight::AddCallHeaders* outgoing_headers) {
std::string bearer_token =
FindKeyValPrefixInCallHeaders(_incoming_headers, kAuthHeader, kBearerPrefix);
*_is_valid = (bearer_token == std::string(kBearerDefaultToken));
}

arrow::Status NoOpBearerAuthServerMiddlewareFactory::StartCall(
const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
*middleware = std::make_shared<NoOpBearerAuthServerMiddleware>(context.incoming_headers(),
&_is_valid);
return arrow::Status::OK();
}

} // namespace flight
} // namespace doris
88 changes: 88 additions & 0 deletions be/src/service/arrow_flight/auth_server_middleware.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <arrow/status.h>

#include "arrow/flight/server.h"
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/types.h"

namespace doris {
namespace flight {

// Just return default bearer token.
class NoOpHeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware {
public:
void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;

void CallCompleted(const arrow::Status& status) override {}

[[nodiscard]] std::string name() const override { return "NoOpHeaderAuthServerMiddleware"; }
};

// Factory for base64 header authentication.
// No actual authentication.
class NoOpHeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
NoOpHeaderAuthServerMiddlewareFactory() = default;

arrow::Status StartCall(const arrow::flight::CallInfo& info,
const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
};

// A server middleware for validating incoming bearer header authentication.
// Just compare with default bearer token.
class NoOpBearerAuthServerMiddleware : public arrow::flight::ServerMiddleware {
public:
explicit NoOpBearerAuthServerMiddleware(const arrow::flight::CallHeaders& incoming_headers,
bool* isValid)
: _is_valid(isValid) {
_incoming_headers = incoming_headers;
}

void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;

void CallCompleted(const arrow::Status& status) override {}

[[nodiscard]] std::string name() const override { return "NoOpBearerAuthServerMiddleware"; }

private:
arrow::flight::CallHeaders _incoming_headers;
bool* _is_valid;
};

// Factory for base64 header authentication.
// No actual authentication.
class NoOpBearerAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
NoOpBearerAuthServerMiddlewareFactory() : _is_valid(false) {}

arrow::Status StartCall(const arrow::flight::CallInfo& info,
const arrow::flight::ServerCallContext& context,
std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;

[[nodiscard]] bool GetIsValid() const { return _is_valid; }

private:
bool _is_valid;
};

} // namespace flight
} // namespace doris
63 changes: 63 additions & 0 deletions be/src/service/arrow_flight/call_header_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/flight/types.h"
#include "arrow/util/base64.h"

namespace doris {
namespace flight {

const char kBearerDefaultToken[] = "bearertoken";
const char kBasicPrefix[] = "Basic ";
const char kBearerPrefix[] = "Bearer ";
const char kAuthHeader[] = "authorization";

// Function to look in CallHeaders for a key that has a value starting with prefix and
// return the rest of the value after the prefix.
std::string FindKeyValPrefixInCallHeaders(const arrow::flight::CallHeaders& incoming_headers,
const std::string& key, const std::string& prefix) {
// Lambda function to compare characters without case sensitivity.
auto char_compare = [](const char& char1, const char& char2) {
return (::toupper(char1) == ::toupper(char2));
};

auto iter = incoming_headers.find(key);
if (iter == incoming_headers.end()) {
return "";
}
const std::string val(iter->second);
if (val.size() > prefix.length()) {
if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) {
return val.substr(prefix.length());
}
}
return "";
}

void ParseBasicHeader(const arrow::flight::CallHeaders& incoming_headers, std::string& username,
std::string& password) {
std::string encoded_credentials =
FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
std::getline(decoded_stream, username, ':');
std::getline(decoded_stream, password, ':');
}

} // namespace flight
} // namespace doris
54 changes: 0 additions & 54 deletions be/src/service/arrow_flight/flight_sql_server_auth_handler.cpp

This file was deleted.

40 changes: 0 additions & 40 deletions be/src/service/arrow_flight/flight_sql_server_auth_handler.h

This file was deleted.

19 changes: 14 additions & 5 deletions be/src/service/arrow_flight/flight_sql_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "arrow/flight/sql/server.h"
#include "service/arrow_flight/arrow_flight_batch_reader.h"
#include "service/arrow_flight/flight_sql_info.h"
#include "service/arrow_flight/flight_sql_server_auth_handler.h"
#include "service/backend_options.h"
#include "util/arrow/utils.h"
#include "util/uid_util.h"
Expand Down Expand Up @@ -75,7 +74,7 @@ class FlightSqlServer::Impl {
}
};

FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}
FlightSqlServer::FlightSqlServer(std::shared_ptr<Impl> impl) : _impl(std::move(impl)) {}

arrow::Result<std::shared_ptr<FlightSqlServer>> FlightSqlServer::create() {
std::shared_ptr<Impl> impl = std::make_shared<Impl>();
Expand All @@ -95,7 +94,7 @@ FlightSqlServer::~FlightSqlServer() {
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> FlightSqlServer::DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command) {
return impl_->DoGetStatement(context, command);
return _impl->DoGetStatement(context, command);
}

Status FlightSqlServer::init(int port) {
Expand All @@ -109,8 +108,18 @@ Status FlightSqlServer::init(int port) {
arrow::flight::Location::ForGrpcTcp(BackendOptions::get_service_bind_address(), port)
.Value(&bind_location));
arrow::flight::FlightServerOptions flight_options(bind_location);
flight_options.auth_handler = std::make_unique<FlightSqlServerAuthHandler>();
// flight_options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();

// Not authenticated in BE flight server.
// After the authentication between the ADBC Client and the FE flight server is completed,
// the FE flight server will put the query id in the Ticket and send it back to the Client.
// When the Client uses the Ticket to fetch data from the BE flight server, the BE flight
// server will verify the query id, this step is equivalent to authentication.
_header_middleware = std::make_shared<NoOpHeaderAuthServerMiddlewareFactory>();
_bearer_middleware = std::make_shared<NoOpBearerAuthServerMiddlewareFactory>();
flight_options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
flight_options.middleware.push_back({"header-auth-server", _header_middleware});
flight_options.middleware.push_back({"bearer-auth-server", _bearer_middleware});

RETURN_DORIS_STATUS_IF_ERROR(Init(flight_options));
LOG(INFO) << "Arrow Flight Service bind to host: " << BackendOptions::get_service_bind_address()
<< ", port: " << port;
Expand Down
6 changes: 5 additions & 1 deletion be/src/service/arrow_flight/flight_sql_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "arrow/result.h"
#include "common/status.h"
#include "service/arrow_flight/arrow_flight_batch_reader.h"
#include "service/arrow_flight/auth_server_middleware.h"

namespace doris {
namespace flight {
Expand All @@ -40,9 +41,12 @@ class FlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {

private:
class Impl;
std::shared_ptr<Impl> impl_;
std::shared_ptr<Impl> _impl;
bool _inited = false;

std::shared_ptr<HeaderAuthServerMiddlewareFactory> _header_middleware;
std::shared_ptr<BearerAuthServerMiddlewareFactory> _bearer_middleware;

explicit FlightSqlServer(std::shared_ptr<Impl> impl);
};

Expand Down

0 comments on commit fc38284

Please sign in to comment.