From 59c2d1db3200015216be1299b87a21524140902b Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Wed, 11 Dec 2024 21:20:18 +0000 Subject: [PATCH] impl Signed-off-by: Takeshi Yoneda --- .../filters/common/ratelimit/ratelimit.h | 2 +- .../common/ratelimit/ratelimit_impl.cc | 13 ++-- .../filters/common/ratelimit/ratelimit_impl.h | 2 +- .../filters/http/ratelimit/config.cc | 6 +- .../filters/http/ratelimit/ratelimit.cc | 59 +++++++++++------ .../filters/http/ratelimit/ratelimit.h | 66 +++++++++++++++++-- .../filters/common/ratelimit/mocks.h | 2 +- test/extensions/filters/http/ratelimit/BUILD | 1 + .../filters/http/ratelimit/ratelimit_test.cc | 4 +- 9 files changed, 116 insertions(+), 39 deletions(-) diff --git a/source/extensions/filters/common/ratelimit/ratelimit.h b/source/extensions/filters/common/ratelimit/ratelimit.h index 11267cc7db03..44a29d13da6a 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit.h +++ b/source/extensions/filters/common/ratelimit/ratelimit.h @@ -90,7 +90,7 @@ class Client { */ virtual void limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend) PURE; }; diff --git a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc index 3350e132562a..4b8b1866d005 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc +++ b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc @@ -59,18 +59,19 @@ void GrpcClientImpl::createRequest(envoy::service::ratelimit::v3::RateLimitReque void GrpcClientImpl::limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, - uint32_t hits_addend) { + Tracing::Span& parent_span, + OptRef stream_info, uint32_t hits_addend) { ASSERT(callbacks_ == nullptr); callbacks_ = &callbacks; envoy::service::ratelimit::v3::RateLimitRequest request; createRequest(request, domain, descriptors, hits_addend); - request_ = - async_client_->send(service_method_, request, *this, parent_span, - Http::AsyncClient::RequestOptions().setTimeout(timeout_).setParentContext( - Http::AsyncClient::ParentContext{&stream_info})); + auto options = Http::AsyncClient::RequestOptions().setTimeout(timeout_); + if (stream_info) { + options.setParentContext(Http::AsyncClient::ParentContext{&*stream_info}); + } + request_ = async_client_->send(service_method_, request, *this, parent_span, options); } void GrpcClientImpl::onSuccess( diff --git a/source/extensions/filters/common/ratelimit/ratelimit_impl.h b/source/extensions/filters/common/ratelimit/ratelimit_impl.h index 79502ec2ef78..61a6c1c5ec88 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit_impl.h +++ b/source/extensions/filters/common/ratelimit/ratelimit_impl.h @@ -57,7 +57,7 @@ class GrpcClientImpl : public Client, void cancel() override; void limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend = 0) override; // Grpc::AsyncRequestCallbacks diff --git a/source/extensions/filters/http/ratelimit/config.cc b/source/extensions/filters/http/ratelimit/config.cc index 59711a014b82..a9dfb12117ef 100644 --- a/source/extensions/filters/http/ratelimit/config.cc +++ b/source/extensions/filters/http/ratelimit/config.cc @@ -23,9 +23,9 @@ Http::FilterFactoryCb RateLimitFilterConfig::createFilterFactoryFromProtoTyped( auto& server_context = context.serverFactoryContext(); ASSERT(!proto_config.domain().empty()); - FilterConfigSharedPtr filter_config(new FilterConfig(proto_config, server_context.localInfo(), - context.scope(), server_context.runtime(), - server_context.httpContext())); + FilterConfigSharedPtr filter_config(new FilterConfig( + proto_config, server_context.localInfo(), context.scope(), server_context.runtime(), + server_context.httpContext(), server_context.threadLocal())); const std::chrono::milliseconds timeout = std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(proto_config, timeout, 20)); diff --git a/source/extensions/filters/http/ratelimit/ratelimit.cc b/source/extensions/filters/http/ratelimit/ratelimit.cc index 89d3f6e00045..eb5e919bc636 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.cc +++ b/source/extensions/filters/http/ratelimit/ratelimit.cc @@ -1,5 +1,6 @@ #include "source/extensions/filters/http/ratelimit/ratelimit.h" +#include #include #include @@ -83,27 +84,27 @@ void Filter::initiateCall(const Http::RequestHeaderMap& headers) { } break; } - makeRateLimitRequest(); -} -void Filter::makeRateLimitRequest() { if (!descriptors_.empty()) { - const StreamInfo::UInt32Accessor* hits_addend_filter_state = - callbacks_->streamInfo().filterState()->getDataReadOnly( - HitsAddendFilterStateKey); - double hits_addend = 0; - if (hits_addend_filter_state != nullptr) { - hits_addend = hits_addend_filter_state->value(); - } - state_ = State::Calling; initiating_call_ = true; client_->limit(*this, getDomain(), descriptors_, callbacks_->activeSpan(), - callbacks_->streamInfo(), hits_addend); + callbacks_->streamInfo(), getHitAddend()); initiating_call_ = false; } } +double Filter::getHitAddend() { + const StreamInfo::UInt32Accessor* hits_addend_filter_state = + callbacks_->streamInfo().filterState()->getDataReadOnly( + HitsAddendFilterStateKey); + double hits_addend = 0; + if (hits_addend_filter_state != nullptr) { + hits_addend = hits_addend_filter_state->value(); + } + return hits_addend; +} + Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { if (!config_->runtime().snapshot().featureEnabled("ratelimit.http_filter_enabled", 100)) { return Http::FilterHeadersStatus::Continue; @@ -159,13 +160,20 @@ Http::FilterMetadataStatus Filter::encodeMetadata(Http::MetadataMap&) { void Filter::setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks&) {} void Filter::onDestroy() { - if (config_->applyOnStreamDone()) { - state_ = State::OnStreamDone; - makeRateLimitRequest(); + if (state_ == State::Calling) { + state_ = State::Complete; + client_->cancel(); } else { - if (state_ == State::Calling) { - state_ = State::Complete; - client_->cancel(); + // If the filter doesn't have a outstanding limit request (made during decodeHeaders) and has descriptors, + // then we can apply the rate limit on stream done if the config allows it. + if (config_->applyOnStreamDone() && !descriptors_.empty()) { + client_->cancel(); // Clears the internal state of the client, so that we can reuse it. + client_->limit(*this, getDomain(), descriptors_, Tracing::NullSpan::instance(), absl::nullopt, + getHitAddend()); + state_ = State::PendingReuqestOnStreamDone; + // Since this filter is being destroyed, we need to keep the client alive until the request + // is complete. So we add this filter to the destroy pending list at the filter config level. + config_->addDestroyPendingFilter(shared_from_this()); } } } @@ -176,9 +184,11 @@ void Filter::complete(Filters::Common::RateLimit::LimitStatus status, Http::RequestHeaderMapPtr&& request_headers_to_add, const std::string& response_body, Filters::Common::RateLimit::DynamicMetadataPtr&& dynamic_metadata) { - if (state_ == State::OnStreamDone) { - // We have no more work to do as the rate limit request made during on completion is - // fire-and-forget. + if (state_ == State::PendingReuqestOnStreamDone) { + // Since this filter is already destroyed from HCM perspective, there's nothing to do here. + // Simply remove it from the destroy pending list which in turn will release the filter shared + // pointer. + config_->removeDestroyPendingFilter(*this); return; } state_ = State::Complete; @@ -341,6 +351,13 @@ std::string Filter::getDomain() { return config_->domain(); } +DestroyPendingFilterThreadLocal::~DestroyPendingFilterThreadLocal() { + for (const auto& filter : map_) { + filter.second->client()->cancel(); + } + map_.clear(); +} + } // namespace RateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/ratelimit/ratelimit.h b/source/extensions/filters/http/ratelimit/ratelimit.h index f689711ae41d..ede30e08b690 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.h +++ b/source/extensions/filters/http/ratelimit/ratelimit.h @@ -35,6 +35,37 @@ enum class FilterRequestType { Internal, External, Both }; */ enum class VhRateLimitOptions { Override, Include, Ignore }; +class Filter; +using FilterSharedPtr = std::shared_ptr; + +/** + * Thread local storage for destroy pending filters. + */ +class DestroyPendingFilterThreadLocal : public ThreadLocal::ThreadLocalObject { +public: + DestroyPendingFilterThreadLocal() = default; + // When this class is destructed, it will cancel all the pending filters and release the shared + // pointers. + ~DestroyPendingFilterThreadLocal() override; + /** + * Add the filter to the destroy pending list. + * @param filter the filter to add. + */ + void addDestroyPendingFilter(FilterSharedPtr filter) { + map_.emplace(filter.get(), std::move(filter)); + } + /** + * Remove the filter from the destroy pending list. + * @param filter the const reference to the filter to remove. + */ + void removeDestroyPendingFilter(const Filter& filter) { map_.erase(&filter); } + +private: + // We use a raw pointer as the key to be able to remove filters when the callback is called where + // shared_from_this() is not available. + absl::flat_hash_map map_ = {}; +}; + /** * Global configuration for the HTTP rate limit filter. */ @@ -42,7 +73,8 @@ class FilterConfig { public: FilterConfig(const envoy::extensions::filters::http::ratelimit::v3::RateLimit& config, const LocalInfo::LocalInfo& local_info, Stats::Scope& scope, - Runtime::Loader& runtime, Http::Context& http_context) + Runtime::Loader& runtime, Http::Context& http_context, + ThreadLocal::SlotAllocator& tls_allocator) : domain_(config.domain()), stage_(static_cast(config.stage())), request_type_(config.request_type().empty() ? stringToType("both") : stringToType(config.request_type())), @@ -62,7 +94,16 @@ class FilterConfig { Envoy::Router::HeaderParser::configure(config.response_headers_to_add()), Router::HeaderParserPtr)), status_on_error_(toRatelimitServerErrorCode(config.status_on_error().code())), - apply_on_stream_done_(config.apply_on_stream_done()) {} + apply_on_stream_done_(config.apply_on_stream_done()) { + if (config.apply_on_stream_done()) { + pending_filter_slot_ = + ThreadLocal::TypedSlot::makeUnique(tls_allocator); + pending_filter_slot_->set( + [](Event::Dispatcher&) -> std::shared_ptr { + return std::make_shared(); + }); + } + } const std::string& domain() const { return domain_; } const LocalInfo::LocalInfo& localInfo() const { return local_info_; } uint64_t stage() const { return stage_; } @@ -81,6 +122,14 @@ class FilterConfig { const Router::HeaderParser& responseHeadersParser() const { return *response_headers_parser_; } Http::Code statusOnError() const { return status_on_error_; } bool applyOnStreamDone() const { return apply_on_stream_done_; } + void addDestroyPendingFilter(FilterSharedPtr filter) { + ASSERT(pending_filter_slot_); + (*pending_filter_slot_)->addDestroyPendingFilter(std::move(filter)); + } + void removeDestroyPendingFilter(const Filter& filter) { + ASSERT(pending_filter_slot_); + (*pending_filter_slot_)->removeDestroyPendingFilter(filter); + } private: static FilterRequestType stringToType(const std::string& request_type) { @@ -110,6 +159,9 @@ class FilterConfig { return Http::Code::InternalServerError; } + // Thread local storage for destory pending Filter by a hash map of shared pointers. + ThreadLocal::TypedSlotPtr pending_filter_slot_ = nullptr; + const std::string domain_; const uint64_t stage_; const FilterRequestType request_type_; @@ -154,7 +206,9 @@ class FilterConfigPerRoute : public Router::RouteSpecificFilterConfig { * HTTP rate limit filter. Depending on the route configuration, this filter calls the global * rate limiting service before allowing further filter iteration. */ -class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::RequestCallbacks { +class Filter : public Http::StreamFilter, + public Filters::Common::RateLimit::RequestCallbacks, + std::enable_shared_from_this { public: Filter(FilterConfigSharedPtr config, Filters::Common::RateLimit::ClientPtr&& client) : config_(config), client_(std::move(client)) {} @@ -186,20 +240,22 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req const std::string& response_body, Filters::Common::RateLimit::DynamicMetadataPtr&& dynamic_metadata) override; + Filters::Common::RateLimit::ClientPtr& client() { return client_; } + private: void initiateCall(const Http::RequestHeaderMap& headers); - void makeRateLimitRequest(); void populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_limit_policy, std::vector& descriptors, const Http::RequestHeaderMap& headers) const; void populateResponseHeaders(Http::HeaderMap& response_headers, bool from_local_reply); void appendRequestHeaders(Http::HeaderMapPtr& request_headers_to_add); + double getHitAddend(); VhRateLimitOptions getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route); std::string getDomain(); Http::Context& httpContext() { return config_->httpContext(); } - enum class State { NotStarted, Calling, Complete, Responded, OnStreamDone }; + enum class State { NotStarted, Calling, Complete, Responded, PendingReuqestOnStreamDone }; FilterConfigSharedPtr config_; Filters::Common::RateLimit::ClientPtr client_; diff --git a/test/extensions/filters/common/ratelimit/mocks.h b/test/extensions/filters/common/ratelimit/mocks.h index 5155d335f057..259fa3f08881 100644 --- a/test/extensions/filters/common/ratelimit/mocks.h +++ b/test/extensions/filters/common/ratelimit/mocks.h @@ -26,7 +26,7 @@ class MockClient : public Client { MOCK_METHOD(void, limit, (RequestCallbacks & callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend)); }; diff --git a/test/extensions/filters/http/ratelimit/BUILD b/test/extensions/filters/http/ratelimit/BUILD index db8ba366d5c5..be0479341c9b 100644 --- a/test/extensions/filters/http/ratelimit/BUILD +++ b/test/extensions/filters/http/ratelimit/BUILD @@ -29,6 +29,7 @@ envoy_extension_cc_test( "//test/mocks/local_info:local_info_mocks", "//test/mocks/ratelimit:ratelimit_mocks", "//test/mocks/runtime:runtime_mocks", + "//test/mocks/thread_local:thread_local_mocks", "//test/mocks/tracing:tracing_mocks", "//test/test_common:utility_lib", "@envoy_api//envoy/extensions/filters/http/ratelimit/v3:pkg_cc_proto", diff --git a/test/extensions/filters/http/ratelimit/ratelimit_test.cc b/test/extensions/filters/http/ratelimit/ratelimit_test.cc index 6caa81e8eb5f..047e190f914c 100644 --- a/test/extensions/filters/http/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/http/ratelimit/ratelimit_test.cc @@ -18,6 +18,7 @@ #include "test/mocks/local_info/mocks.h" #include "test/mocks/ratelimit/mocks.h" #include "test/mocks/runtime/mocks.h" +#include "test/mocks/thread_local/mocks.h" #include "test/mocks/tracing/mocks.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" @@ -56,7 +57,7 @@ class HttpRateLimitFilterTest : public testing::Test { TestUtility::loadFromYaml(yaml, proto_config); config_ = std::make_shared(proto_config, local_info_, *stats_store_.rootScope(), - runtime_, http_context_); + runtime_, http_context_, thread_local_); client_ = new Filters::Common::RateLimit::MockClient(); filter_ = std::make_unique(config_, Filters::Common::RateLimit::ClientPtr{client_}); @@ -134,6 +135,7 @@ class HttpRateLimitFilterTest : public testing::Test { FilterConfigSharedPtr config_; std::unique_ptr filter_; NiceMock runtime_; + NiceMock thread_local_; NiceMock route_rate_limit_; NiceMock vh_rate_limit_; std::vector descriptor_{{{{"descriptor_key", "descriptor_value"}}}};