From fc38284c0c11f69c7e3c075932b0078240773d64 Mon Sep 17 00:00:00 2001 From: Xinyi Zou Date: Fri, 22 Sep 2023 17:01:18 +0800 Subject: [PATCH] 3 --- .../arrow_flight/auth_server_middleware.cpp | 55 ++++++++++++ .../arrow_flight/auth_server_middleware.h | 88 +++++++++++++++++++ .../service/arrow_flight/call_header_utils.h | 63 +++++++++++++ .../flight_sql_server_auth_handler.cpp | 54 ------------ .../flight_sql_server_auth_handler.h | 40 --------- .../arrow_flight/flight_sql_service.cpp | 19 ++-- .../service/arrow_flight/flight_sql_service.h | 6 +- 7 files changed, 225 insertions(+), 100 deletions(-) create mode 100644 be/src/service/arrow_flight/auth_server_middleware.cpp create mode 100644 be/src/service/arrow_flight/auth_server_middleware.h create mode 100644 be/src/service/arrow_flight/call_header_utils.h delete mode 100644 be/src/service/arrow_flight/flight_sql_server_auth_handler.cpp delete mode 100644 be/src/service/arrow_flight/flight_sql_server_auth_handler.h diff --git a/be/src/service/arrow_flight/auth_server_middleware.cpp b/be/src/service/arrow_flight/auth_server_middleware.cpp new file mode 100644 index 000000000000000..c0bf5b853b7df29 --- /dev/null +++ b/be/src/service/arrow_flight/auth_server_middleware.cpp @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "service/arrow_flight/auth_server_middleware.h" + +#include "service/arrow_flight/call_header_utils.h" + +namespace doris { +namespace flight { + +void NoOpHeaderAuthServerMiddleware::SendingHeaders( + arrow::flight::AddCallHeaders* outgoing_headers) { + outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerDefaultToken); +} + +arrow::Status NoOpHeaderAuthServerMiddlewareFactory::StartCall( + const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) { + std::string username, password; + ParseBasicHeader(context.incoming_headers(), username, password); + *middleware = std::make_shared(); + return arrow::Status::OK(); +} + +void NoOpBearerAuthServerMiddleware::SendingHeaders( + arrow::flight::AddCallHeaders* outgoing_headers) { + std::string bearer_token = + FindKeyValPrefixInCallHeaders(_incoming_headers, kAuthHeader, kBearerPrefix); + *_is_valid = (bearer_token == std::string(kBearerDefaultToken)); +} + +arrow::Status NoOpBearerAuthServerMiddlewareFactory::StartCall( + const arrow::flight::CallInfo& info, const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) { + *middleware = std::make_shared(context.incoming_headers(), + &_is_valid); + return arrow::Status::OK(); +} + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/auth_server_middleware.h b/be/src/service/arrow_flight/auth_server_middleware.h new file mode 100644 index 000000000000000..e5f40cf626455c6 --- /dev/null +++ b/be/src/service/arrow_flight/auth_server_middleware.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/types.h" + +namespace doris { +namespace flight { + +// Just return default bearer token. +class NoOpHeaderAuthServerMiddleware : public arrow::flight::ServerMiddleware { +public: + void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const arrow::Status& status) override {} + + [[nodiscard]] std::string name() const override { return "NoOpHeaderAuthServerMiddleware"; } +}; + +// Factory for base64 header authentication. +// No actual authentication. +class NoOpHeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory { +public: + NoOpHeaderAuthServerMiddlewareFactory() = default; + + arrow::Status StartCall(const arrow::flight::CallInfo& info, + const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) override; +}; + +// A server middleware for validating incoming bearer header authentication. +// Just compare with default bearer token. +class NoOpBearerAuthServerMiddleware : public arrow::flight::ServerMiddleware { +public: + explicit NoOpBearerAuthServerMiddleware(const arrow::flight::CallHeaders& incoming_headers, + bool* isValid) + : _is_valid(isValid) { + _incoming_headers = incoming_headers; + } + + void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const arrow::Status& status) override {} + + [[nodiscard]] std::string name() const override { return "NoOpBearerAuthServerMiddleware"; } + +private: + arrow::flight::CallHeaders _incoming_headers; + bool* _is_valid; +}; + +// Factory for base64 header authentication. +// No actual authentication. +class NoOpBearerAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory { +public: + NoOpBearerAuthServerMiddlewareFactory() : _is_valid(false) {} + + arrow::Status StartCall(const arrow::flight::CallInfo& info, + const arrow::flight::ServerCallContext& context, + std::shared_ptr* middleware) override; + + [[nodiscard]] bool GetIsValid() const { return _is_valid; } + +private: + bool _is_valid; +}; + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/call_header_utils.h b/be/src/service/arrow_flight/call_header_utils.h new file mode 100644 index 000000000000000..b4e0e512beece4c --- /dev/null +++ b/be/src/service/arrow_flight/call_header_utils.h @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" + +namespace doris { +namespace flight { + +const char kBearerDefaultToken[] = "bearertoken"; +const char kBasicPrefix[] = "Basic "; +const char kBearerPrefix[] = "Bearer "; +const char kAuthHeader[] = "authorization"; + +// Function to look in CallHeaders for a key that has a value starting with prefix and +// return the rest of the value after the prefix. +std::string FindKeyValPrefixInCallHeaders(const arrow::flight::CallHeaders& incoming_headers, + const std::string& key, const std::string& prefix) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + auto iter = incoming_headers.find(key); + if (iter == incoming_headers.end()) { + return ""; + } + const std::string val(iter->second); + if (val.size() > prefix.length()) { + if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) { + return val.substr(prefix.length()); + } + } + return ""; +} + +void ParseBasicHeader(const arrow::flight::CallHeaders& incoming_headers, std::string& username, + std::string& password) { + std::string encoded_credentials = + FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); + std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); +} + +} // namespace flight +} // namespace doris diff --git a/be/src/service/arrow_flight/flight_sql_server_auth_handler.cpp b/be/src/service/arrow_flight/flight_sql_server_auth_handler.cpp deleted file mode 100644 index 63d5ed99cb4c20d..000000000000000 --- a/be/src/service/arrow_flight/flight_sql_server_auth_handler.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "service/arrow_flight/flight_sql_server_auth_handler.h" - -#include - -namespace doris { -namespace flight { - -FlightSqlServerAuthHandler::FlightSqlServerAuthHandler() = default; - -FlightSqlServerAuthHandler::~FlightSqlServerAuthHandler() = default; - -arrow::Status FlightSqlServerAuthHandler::Authenticate( - const arrow::flight::ServerCallContext& context, arrow::flight::ServerAuthSender* outgoing, - arrow::flight::ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - ARROW_ASSIGN_OR_RAISE(arrow::flight::BasicAuth incoming_auth, - arrow::flight::BasicAuth::Deserialize(token)); - if (incoming_auth.username == "" || incoming_auth.password == "") { - return MakeFlightError(arrow::flight::FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(incoming_auth.username)); - return arrow::Status::OK(); -} - -arrow::Status FlightSqlServerAuthHandler::IsValid(const arrow::flight::ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token == "") { - return MakeFlightError(arrow::flight::FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = token; - return arrow::Status::OK(); -} - -} // namespace flight -} // namespace doris diff --git a/be/src/service/arrow_flight/flight_sql_server_auth_handler.h b/be/src/service/arrow_flight/flight_sql_server_auth_handler.h deleted file mode 100644 index f30c5c6c9476a74..000000000000000 --- a/be/src/service/arrow_flight/flight_sql_server_auth_handler.h +++ /dev/null @@ -1,40 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/flight/server_auth.h" -#include "arrow/flight/types.h" -#include "arrow/result.h" - -namespace doris { -namespace flight { - -class FlightSqlServerAuthHandler : public arrow::flight::ServerAuthHandler { -public: - explicit FlightSqlServerAuthHandler(); - ~FlightSqlServerAuthHandler() override; - arrow::Status Authenticate(const arrow::flight::ServerCallContext& context, - arrow::flight::ServerAuthSender* outgoing, - arrow::flight::ServerAuthReader* incoming) override; - - arrow::Status IsValid(const arrow::flight::ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; -}; - -} // namespace flight -} // namespace doris diff --git a/be/src/service/arrow_flight/flight_sql_service.cpp b/be/src/service/arrow_flight/flight_sql_service.cpp index 5c99eeaadd1a5f6..719f7a466cccc7f 100644 --- a/be/src/service/arrow_flight/flight_sql_service.cpp +++ b/be/src/service/arrow_flight/flight_sql_service.cpp @@ -22,7 +22,6 @@ #include "arrow/flight/sql/server.h" #include "service/arrow_flight/arrow_flight_batch_reader.h" #include "service/arrow_flight/flight_sql_info.h" -#include "service/arrow_flight/flight_sql_server_auth_handler.h" #include "service/backend_options.h" #include "util/arrow/utils.h" #include "util/uid_util.h" @@ -75,7 +74,7 @@ class FlightSqlServer::Impl { } }; -FlightSqlServer::FlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} +FlightSqlServer::FlightSqlServer(std::shared_ptr impl) : _impl(std::move(impl)) {} arrow::Result> FlightSqlServer::create() { std::shared_ptr impl = std::make_shared(); @@ -95,7 +94,7 @@ FlightSqlServer::~FlightSqlServer() { arrow::Result> FlightSqlServer::DoGetStatement( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::StatementQueryTicket& command) { - return impl_->DoGetStatement(context, command); + return _impl->DoGetStatement(context, command); } Status FlightSqlServer::init(int port) { @@ -109,8 +108,18 @@ Status FlightSqlServer::init(int port) { arrow::flight::Location::ForGrpcTcp(BackendOptions::get_service_bind_address(), port) .Value(&bind_location)); arrow::flight::FlightServerOptions flight_options(bind_location); - flight_options.auth_handler = std::make_unique(); - // flight_options.auth_handler = std::make_unique(); + + // Not authenticated in BE flight server. + // After the authentication between the ADBC Client and the FE flight server is completed, + // the FE flight server will put the query id in the Ticket and send it back to the Client. + // When the Client uses the Ticket to fetch data from the BE flight server, the BE flight + // server will verify the query id, this step is equivalent to authentication. + _header_middleware = std::make_shared(); + _bearer_middleware = std::make_shared(); + flight_options.auth_handler = std::make_unique(); + flight_options.middleware.push_back({"header-auth-server", _header_middleware}); + flight_options.middleware.push_back({"bearer-auth-server", _bearer_middleware}); + RETURN_DORIS_STATUS_IF_ERROR(Init(flight_options)); LOG(INFO) << "Arrow Flight Service bind to host: " << BackendOptions::get_service_bind_address() << ", port: " << port; diff --git a/be/src/service/arrow_flight/flight_sql_service.h b/be/src/service/arrow_flight/flight_sql_service.h index 4772e98d81d114a..8f3ed088c293e99 100644 --- a/be/src/service/arrow_flight/flight_sql_service.h +++ b/be/src/service/arrow_flight/flight_sql_service.h @@ -21,6 +21,7 @@ #include "arrow/result.h" #include "common/status.h" #include "service/arrow_flight/arrow_flight_batch_reader.h" +#include "service/arrow_flight/auth_server_middleware.h" namespace doris { namespace flight { @@ -40,9 +41,12 @@ class FlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { private: class Impl; - std::shared_ptr impl_; + std::shared_ptr _impl; bool _inited = false; + std::shared_ptr _header_middleware; + std::shared_ptr _bearer_middleware; + explicit FlightSqlServer(std::shared_ptr impl); };