Skip to content

Commit

Permalink
impl
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake committed Dec 11, 2024
1 parent 6598d44 commit 59c2d1d
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 39 deletions.
2 changes: 1 addition & 1 deletion source/extensions/filters/common/ratelimit/ratelimit.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Client {
*/
virtual void limit(RequestCallbacks& callbacks, const std::string& domain,
const std::vector<Envoy::RateLimit::Descriptor>& descriptors,
Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info,
Tracing::Span& parent_span, OptRef<const StreamInfo::StreamInfo> stream_info,
uint32_t hits_addend) PURE;
};

Expand Down
13 changes: 7 additions & 6 deletions source/extensions/filters/common/ratelimit/ratelimit_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,19 @@ void GrpcClientImpl::createRequest(envoy::service::ratelimit::v3::RateLimitReque

void GrpcClientImpl::limit(RequestCallbacks& callbacks, const std::string& domain,
const std::vector<Envoy::RateLimit::Descriptor>& descriptors,
Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info,
uint32_t hits_addend) {
Tracing::Span& parent_span,
OptRef<const StreamInfo::StreamInfo> stream_info, uint32_t hits_addend) {
ASSERT(callbacks_ == nullptr);
callbacks_ = &callbacks;

envoy::service::ratelimit::v3::RateLimitRequest request;
createRequest(request, domain, descriptors, hits_addend);

request_ =
async_client_->send(service_method_, request, *this, parent_span,
Http::AsyncClient::RequestOptions().setTimeout(timeout_).setParentContext(
Http::AsyncClient::ParentContext{&stream_info}));
auto options = Http::AsyncClient::RequestOptions().setTimeout(timeout_);
if (stream_info) {
options.setParentContext(Http::AsyncClient::ParentContext{&*stream_info});
}
request_ = async_client_->send(service_method_, request, *this, parent_span, options);
}

void GrpcClientImpl::onSuccess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GrpcClientImpl : public Client,
void cancel() override;
void limit(RequestCallbacks& callbacks, const std::string& domain,
const std::vector<Envoy::RateLimit::Descriptor>& descriptors,
Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info,
Tracing::Span& parent_span, OptRef<const StreamInfo::StreamInfo> stream_info,
uint32_t hits_addend = 0) override;

// Grpc::AsyncRequestCallbacks
Expand Down
6 changes: 3 additions & 3 deletions source/extensions/filters/http/ratelimit/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ Http::FilterFactoryCb RateLimitFilterConfig::createFilterFactoryFromProtoTyped(
auto& server_context = context.serverFactoryContext();

ASSERT(!proto_config.domain().empty());
FilterConfigSharedPtr filter_config(new FilterConfig(proto_config, server_context.localInfo(),
context.scope(), server_context.runtime(),
server_context.httpContext()));
FilterConfigSharedPtr filter_config(new FilterConfig(
proto_config, server_context.localInfo(), context.scope(), server_context.runtime(),
server_context.httpContext(), server_context.threadLocal()));
const std::chrono::milliseconds timeout =
std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(proto_config, timeout, 20));

Expand Down
59 changes: 38 additions & 21 deletions source/extensions/filters/http/ratelimit/ratelimit.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "source/extensions/filters/http/ratelimit/ratelimit.h"

#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -83,27 +84,27 @@ void Filter::initiateCall(const Http::RequestHeaderMap& headers) {
}
break;
}
makeRateLimitRequest();
}

void Filter::makeRateLimitRequest() {
if (!descriptors_.empty()) {
const StreamInfo::UInt32Accessor* hits_addend_filter_state =
callbacks_->streamInfo().filterState()->getDataReadOnly<StreamInfo::UInt32Accessor>(
HitsAddendFilterStateKey);
double hits_addend = 0;
if (hits_addend_filter_state != nullptr) {
hits_addend = hits_addend_filter_state->value();
}

state_ = State::Calling;
initiating_call_ = true;
client_->limit(*this, getDomain(), descriptors_, callbacks_->activeSpan(),
callbacks_->streamInfo(), hits_addend);
callbacks_->streamInfo(), getHitAddend());
initiating_call_ = false;
}
}

double Filter::getHitAddend() {
const StreamInfo::UInt32Accessor* hits_addend_filter_state =
callbacks_->streamInfo().filterState()->getDataReadOnly<StreamInfo::UInt32Accessor>(
HitsAddendFilterStateKey);
double hits_addend = 0;
if (hits_addend_filter_state != nullptr) {
hits_addend = hits_addend_filter_state->value();
}
return hits_addend;
}

Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) {
if (!config_->runtime().snapshot().featureEnabled("ratelimit.http_filter_enabled", 100)) {
return Http::FilterHeadersStatus::Continue;
Expand Down Expand Up @@ -159,13 +160,20 @@ Http::FilterMetadataStatus Filter::encodeMetadata(Http::MetadataMap&) {
void Filter::setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks&) {}

void Filter::onDestroy() {
if (config_->applyOnStreamDone()) {
state_ = State::OnStreamDone;
makeRateLimitRequest();
if (state_ == State::Calling) {
state_ = State::Complete;
client_->cancel();
} else {
if (state_ == State::Calling) {
state_ = State::Complete;
client_->cancel();
// If the filter doesn't have a outstanding limit request (made during decodeHeaders) and has descriptors,
// then we can apply the rate limit on stream done if the config allows it.
if (config_->applyOnStreamDone() && !descriptors_.empty()) {
client_->cancel(); // Clears the internal state of the client, so that we can reuse it.
client_->limit(*this, getDomain(), descriptors_, Tracing::NullSpan::instance(), absl::nullopt,
getHitAddend());
state_ = State::PendingReuqestOnStreamDone;
// Since this filter is being destroyed, we need to keep the client alive until the request
// is complete. So we add this filter to the destroy pending list at the filter config level.
config_->addDestroyPendingFilter(shared_from_this());
}
}
}
Expand All @@ -176,9 +184,11 @@ void Filter::complete(Filters::Common::RateLimit::LimitStatus status,
Http::RequestHeaderMapPtr&& request_headers_to_add,
const std::string& response_body,
Filters::Common::RateLimit::DynamicMetadataPtr&& dynamic_metadata) {
if (state_ == State::OnStreamDone) {
// We have no more work to do as the rate limit request made during on completion is
// fire-and-forget.
if (state_ == State::PendingReuqestOnStreamDone) {
// Since this filter is already destroyed from HCM perspective, there's nothing to do here.
// Simply remove it from the destroy pending list which in turn will release the filter shared
// pointer.
config_->removeDestroyPendingFilter(*this);
return;
}
state_ = State::Complete;
Expand Down Expand Up @@ -341,6 +351,13 @@ std::string Filter::getDomain() {
return config_->domain();
}

DestroyPendingFilterThreadLocal::~DestroyPendingFilterThreadLocal() {
for (const auto& filter : map_) {
filter.second->client()->cancel();
}
map_.clear();
}

} // namespace RateLimitFilter
} // namespace HttpFilters
} // namespace Extensions
Expand Down
66 changes: 61 additions & 5 deletions source/extensions/filters/http/ratelimit/ratelimit.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,46 @@ enum class FilterRequestType { Internal, External, Both };
*/
enum class VhRateLimitOptions { Override, Include, Ignore };

class Filter;
using FilterSharedPtr = std::shared_ptr<Filter>;

/**
* Thread local storage for destroy pending filters.
*/
class DestroyPendingFilterThreadLocal : public ThreadLocal::ThreadLocalObject {
public:
DestroyPendingFilterThreadLocal() = default;
// When this class is destructed, it will cancel all the pending filters and release the shared
// pointers.
~DestroyPendingFilterThreadLocal() override;
/**
* Add the filter to the destroy pending list.
* @param filter the filter to add.
*/
void addDestroyPendingFilter(FilterSharedPtr filter) {
map_.emplace(filter.get(), std::move(filter));
}
/**
* Remove the filter from the destroy pending list.
* @param filter the const reference to the filter to remove.
*/
void removeDestroyPendingFilter(const Filter& filter) { map_.erase(&filter); }

private:
// We use a raw pointer as the key to be able to remove filters when the callback is called where
// shared_from_this() is not available.
absl::flat_hash_map<Filter*, FilterSharedPtr> map_ = {};
};

/**
* Global configuration for the HTTP rate limit filter.
*/
class FilterConfig {
public:
FilterConfig(const envoy::extensions::filters::http::ratelimit::v3::RateLimit& config,
const LocalInfo::LocalInfo& local_info, Stats::Scope& scope,
Runtime::Loader& runtime, Http::Context& http_context)
Runtime::Loader& runtime, Http::Context& http_context,
ThreadLocal::SlotAllocator& tls_allocator)
: domain_(config.domain()), stage_(static_cast<uint64_t>(config.stage())),
request_type_(config.request_type().empty() ? stringToType("both")
: stringToType(config.request_type())),
Expand All @@ -62,7 +94,16 @@ class FilterConfig {
Envoy::Router::HeaderParser::configure(config.response_headers_to_add()),
Router::HeaderParserPtr)),
status_on_error_(toRatelimitServerErrorCode(config.status_on_error().code())),
apply_on_stream_done_(config.apply_on_stream_done()) {}
apply_on_stream_done_(config.apply_on_stream_done()) {
if (config.apply_on_stream_done()) {
pending_filter_slot_ =
ThreadLocal::TypedSlot<DestroyPendingFilterThreadLocal>::makeUnique(tls_allocator);
pending_filter_slot_->set(
[](Event::Dispatcher&) -> std::shared_ptr<DestroyPendingFilterThreadLocal> {
return std::make_shared<DestroyPendingFilterThreadLocal>();
});
}
}
const std::string& domain() const { return domain_; }
const LocalInfo::LocalInfo& localInfo() const { return local_info_; }
uint64_t stage() const { return stage_; }
Expand All @@ -81,6 +122,14 @@ class FilterConfig {
const Router::HeaderParser& responseHeadersParser() const { return *response_headers_parser_; }
Http::Code statusOnError() const { return status_on_error_; }
bool applyOnStreamDone() const { return apply_on_stream_done_; }
void addDestroyPendingFilter(FilterSharedPtr filter) {
ASSERT(pending_filter_slot_);
(*pending_filter_slot_)->addDestroyPendingFilter(std::move(filter));
}
void removeDestroyPendingFilter(const Filter& filter) {
ASSERT(pending_filter_slot_);
(*pending_filter_slot_)->removeDestroyPendingFilter(filter);
}

private:
static FilterRequestType stringToType(const std::string& request_type) {
Expand Down Expand Up @@ -110,6 +159,9 @@ class FilterConfig {
return Http::Code::InternalServerError;
}

// Thread local storage for destory pending Filter by a hash map of shared pointers.
ThreadLocal::TypedSlotPtr<DestroyPendingFilterThreadLocal> pending_filter_slot_ = nullptr;

const std::string domain_;
const uint64_t stage_;
const FilterRequestType request_type_;
Expand Down Expand Up @@ -154,7 +206,9 @@ class FilterConfigPerRoute : public Router::RouteSpecificFilterConfig {
* HTTP rate limit filter. Depending on the route configuration, this filter calls the global
* rate limiting service before allowing further filter iteration.
*/
class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::RequestCallbacks {
class Filter : public Http::StreamFilter,
public Filters::Common::RateLimit::RequestCallbacks,
std::enable_shared_from_this<Filter> {
public:
Filter(FilterConfigSharedPtr config, Filters::Common::RateLimit::ClientPtr&& client)
: config_(config), client_(std::move(client)) {}
Expand Down Expand Up @@ -186,20 +240,22 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req
const std::string& response_body,
Filters::Common::RateLimit::DynamicMetadataPtr&& dynamic_metadata) override;

Filters::Common::RateLimit::ClientPtr& client() { return client_; }

private:
void initiateCall(const Http::RequestHeaderMap& headers);
void makeRateLimitRequest();
void populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_limit_policy,
std::vector<Envoy::RateLimit::Descriptor>& descriptors,
const Http::RequestHeaderMap& headers) const;
void populateResponseHeaders(Http::HeaderMap& response_headers, bool from_local_reply);
void appendRequestHeaders(Http::HeaderMapPtr& request_headers_to_add);
double getHitAddend();
VhRateLimitOptions getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route);
std::string getDomain();

Http::Context& httpContext() { return config_->httpContext(); }

enum class State { NotStarted, Calling, Complete, Responded, OnStreamDone };
enum class State { NotStarted, Calling, Complete, Responded, PendingReuqestOnStreamDone };

FilterConfigSharedPtr config_;
Filters::Common::RateLimit::ClientPtr client_;
Expand Down
2 changes: 1 addition & 1 deletion test/extensions/filters/common/ratelimit/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MockClient : public Client {
MOCK_METHOD(void, limit,
(RequestCallbacks & callbacks, const std::string& domain,
const std::vector<Envoy::RateLimit::Descriptor>& descriptors,
Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info,
Tracing::Span& parent_span, OptRef<const StreamInfo::StreamInfo> stream_info,
uint32_t hits_addend));
};

Expand Down
1 change: 1 addition & 0 deletions test/extensions/filters/http/ratelimit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ envoy_extension_cc_test(
"//test/mocks/local_info:local_info_mocks",
"//test/mocks/ratelimit:ratelimit_mocks",
"//test/mocks/runtime:runtime_mocks",
"//test/mocks/thread_local:thread_local_mocks",
"//test/mocks/tracing:tracing_mocks",
"//test/test_common:utility_lib",
"@envoy_api//envoy/extensions/filters/http/ratelimit/v3:pkg_cc_proto",
Expand Down
4 changes: 3 additions & 1 deletion test/extensions/filters/http/ratelimit/ratelimit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "test/mocks/local_info/mocks.h"
#include "test/mocks/ratelimit/mocks.h"
#include "test/mocks/runtime/mocks.h"
#include "test/mocks/thread_local/mocks.h"
#include "test/mocks/tracing/mocks.h"
#include "test/test_common/printers.h"
#include "test/test_common/utility.h"
Expand Down Expand Up @@ -56,7 +57,7 @@ class HttpRateLimitFilterTest : public testing::Test {
TestUtility::loadFromYaml(yaml, proto_config);

config_ = std::make_shared<FilterConfig>(proto_config, local_info_, *stats_store_.rootScope(),
runtime_, http_context_);
runtime_, http_context_, thread_local_);

client_ = new Filters::Common::RateLimit::MockClient();
filter_ = std::make_unique<Filter>(config_, Filters::Common::RateLimit::ClientPtr{client_});
Expand Down Expand Up @@ -134,6 +135,7 @@ class HttpRateLimitFilterTest : public testing::Test {
FilterConfigSharedPtr config_;
std::unique_ptr<Filter> filter_;
NiceMock<Runtime::MockLoader> runtime_;
NiceMock<ThreadLocal::MockInstance> thread_local_;
NiceMock<Router::MockRateLimitPolicyEntry> route_rate_limit_;
NiceMock<Router::MockRateLimitPolicyEntry> vh_rate_limit_;
std::vector<RateLimit::Descriptor> descriptor_{{{{"descriptor_key", "descriptor_value"}}}};
Expand Down

0 comments on commit 59c2d1d

Please sign in to comment.