Skip to content

Commit

Permalink
ratelimit: to supporty custom hits addend
Browse files Browse the repository at this point in the history
Signed-off-by: wangbaiping(wbpcode) <[email protected]>
  • Loading branch information
wbpcode committed Dec 10, 2024
1 parent dad9f3d commit 6944a80
Show file tree
Hide file tree
Showing 13 changed files with 725 additions and 487 deletions.
31 changes: 25 additions & 6 deletions api/envoy/config/route/v3/route_components.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2169,10 +2169,10 @@ message RateLimit {
}
}

// [#not-implemented-hide:]
message HitsAddend {
// Fixed number of hits to add to the rate limit descriptor. Only one of the ``number`` or
// ``format`` fields can be set.
// Fixed number of hits to add to the rate limit descriptor.
//
// One of the ``number`` or ``format`` fields should be set but not both.
google.protobuf.UInt32Value number = 1 [(validate.rules).uint32 = {lte: 100000000}];

// Substitution format string to extract the number of hits to add to the rate limit descriptor.
Expand All @@ -2186,8 +2186,8 @@ message RateLimit {
// For example, the ``%BYTES_RECEIVED%`` format string will be replaced with the number of bytes
// received in the request.
//
// Only one of the ``number`` or ``format`` fields can be set.
string format = 2 [(validate.rules).string = {prefix: "%" suffix: "%" ignore_empty: true}];
// One of the ``number`` or ``format`` fields should be set but not both.
string format = 2 [(validate.rules).string = {prefix: "%", suffix: "%" ignore_empty: true}];
}

// Refers to the stage set in the filter. The rate limit configuration only
Expand All @@ -2197,9 +2197,19 @@ message RateLimit {
// .. note::
//
// The filter supports a range of 0 - 10 inclusively for stage numbers.
//
// .. note::
// This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like
// :ref:`VirtualHost.typed_per_filter_config<envoy_v3_api_field_config.route.v3.VirtualHost.typed_per_filter_config>` or
// :ref:`Route.typed_per_filter_config<envoy_v3_api_field_config_route.v3.Route.typed_per_filter_config>`, etc.
google.protobuf.UInt32Value stage = 1 [(validate.rules).uint32 = {lte: 10}];

// The key to be set in runtime to disable this rate limit configuration.
//
// .. note::
// This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like
// :ref:`VirtualHost.typed_per_filter_config<envoy_v3_api_field_config.route.v3.VirtualHost.typed_per_filter_config>` or
// :ref:`Route.typed_per_filter_config<envoy_v3_api_field_config_route.v3.Route.typed_per_filter_config>`, etc.
string disable_key = 2;

// A list of actions that are to be applied for this rate limit configuration.
Expand All @@ -2214,11 +2224,20 @@ message RateLimit {
// rate limit configuration. If the override value is invalid or cannot be resolved
// from metadata, no override is provided. See :ref:`rate limit override
// <config_http_filters_rate_limit_rate_limit_override>` for more information.
//
// .. note::
// This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like
// :ref:`VirtualHost.typed_per_filter_config<envoy_v3_api_field_config.route.v3.VirtualHost.typed_per_filter_config>` or
// :ref:`Route.typed_per_filter_config<envoy_v3_api_field_config_route.v3.Route.typed_per_filter_config>`, etc.
Override limit = 4;

// An optional hits addend to be appended to the descriptor produced by this rate limit
// configuration.
// [#not-implemented-hide:]
//
// .. note::
// This is only supported if the rate limit action is configured in the ``typed_per_filter_config`` like
// :ref:`VirtualHost.typed_per_filter_config<envoy_v3_api_field_config.route.v3.VirtualHost.typed_per_filter_config>` or
// :ref:`Route.typed_per_filter_config<envoy_v3_api_field_config.route.v3.Route.typed_per_filter_config>`, etc.
HitsAddend hits_addend = 5;
}

Expand Down
31 changes: 17 additions & 14 deletions envoy/ratelimit/ratelimit.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,21 @@ struct DescriptorEntry {
/**
* A single rate limit request descriptor. See ratelimit.proto.
*/
struct Descriptor {
std::vector<DescriptorEntry> entries_;
absl::optional<RateLimitOverride> limit_ = absl::nullopt;
};

/**
* A single rate limit request descriptor. See ratelimit.proto.
*/
struct LocalDescriptor {
struct DescriptorBase {
std::vector<DescriptorEntry> entries_;

friend bool operator==(const LocalDescriptor& a, const LocalDescriptor& b) {
friend bool operator==(const DescriptorBase& a, const DescriptorBase& b) {
return a.entries_ == b.entries_;
}
struct Hash {
using is_transparent = void; // NOLINT(readability-identifier-naming)
size_t operator()(const LocalDescriptor& d) const {
size_t operator()(const DescriptorBase& d) const {
return absl::Hash<std::vector<DescriptorEntry>>()(d.entries_);
}
};
struct Equal {
using is_transparent = void; // NOLINT(readability-identifier-naming)
size_t operator()(const LocalDescriptor& a, const LocalDescriptor& b) const {
size_t operator()(const DescriptorBase& a, const DescriptorBase& b) const {
return a.entries_ == b.entries_;
}
};
Expand All @@ -78,11 +70,22 @@ struct LocalDescriptor {
}

/**
* Local descriptor map.
* Descriptor map.
*/
template <class V> using Map = absl::flat_hash_map<LocalDescriptor, V, Hash, Equal>;
template <class V> using Map = absl::flat_hash_map<DescriptorBase, V, Hash, Equal>;
};

/**
* A single rate limit request descriptor. See ratelimit.proto.
* This is generated from the request based on the configured rate limit actions.
*/
struct Descriptor : public DescriptorBase {
absl::optional<RateLimitOverride> limit_ = absl::nullopt;
absl::optional<uint32_t> hits_addend_ = absl::nullopt;
};

using LocalDescriptor = DescriptorBase;

/*
* Base interface for generic rate limit descriptor producer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,24 +104,24 @@ absl::optional<int64_t> TimerTokenBucket::remainingFillInterval() const {
absl::Seconds((time_after_last_fill) / 1s));
}

bool TimerTokenBucket::consume(double) {
bool TimerTokenBucket::consume(double, uint32_t to_consume) {
// Relaxed consistency is used for all operations because we don't care about ordering, just the
// final atomic correctness.
uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed);
do {
// expected_tokens is either initialized above or reloaded during the CAS failure below.
if (expected_tokens == 0) {
if (expected_tokens < to_consume) {
return false;
}

// Testing hook.
parent_.synchronizer_.syncPoint("allowed_pre_cas");

// Loop while the weak CAS fails trying to subtract 1 from expected.
} while (!tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1,
// Loop while the weak CAS fails trying to subtract tokens from expected.
} while (!tokens_.compare_exchange_weak(expected_tokens, expected_tokens - to_consume,
std::memory_order_relaxed));

// We successfully decremented the counter by 1.
// We successfully decremented the counter by tokens.
return true;
}

Expand Down Expand Up @@ -163,9 +163,9 @@ AtomicTokenBucket::AtomicTokenBucket(uint32_t max_tokens, uint32_t tokens_per_fi
// Calculate the fill rate in tokens per second.
tokens_per_fill / std::chrono::duration<double>(fill_interval).count()) {}

bool AtomicTokenBucket::consume(double factor) {
bool AtomicTokenBucket::consume(double factor, uint32_t to_consume) {
ASSERT(!(factor <= 0.0 || factor > 1.0));
auto cb = [tokens = 1.0 / factor](double total) { return total < tokens ? 0.0 : tokens; };
auto cb = [tokens = to_consume / factor](double total) { return total < tokens ? 0.0 : tokens; };
return token_bucket_.consume(cb) != 0.0;
}

Expand Down Expand Up @@ -269,44 +269,49 @@ void LocalRateLimiterImpl::onFillTimer() {
fill_timer_->enableTimer(default_token_bucket_->fillInterval());
}

struct MatchResult {
std::reference_wrapper<RateLimitTokenBucket> token_bucket;
std::reference_wrapper<const RateLimit::Descriptor> request_descriptor;
};

LocalRateLimiterImpl::Result LocalRateLimiterImpl::requestAllowed(
absl::Span<const RateLimit::LocalDescriptor> request_descriptors) const {
absl::Span<const RateLimit::Descriptor> request_descriptors) const {

// In most cases the request descriptors has only few elements. We use a inlined vector to
// avoid heap allocation.
absl::InlinedVector<RateLimitTokenBucket*, 8> matched_descriptors;
absl::InlinedVector<MatchResult, 8> matched_results;

// Find all matched descriptors.
for (const auto& request_descriptor : request_descriptors) {
auto iter = descriptors_.find(request_descriptor);
if (iter != descriptors_.end()) {
matched_descriptors.push_back(iter->second.get());
matched_results.push_back(MatchResult{*iter->second, request_descriptor});
}
}

if (matched_descriptors.size() > 1) {
if (matched_results.size() > 1) {
// Sort the matched descriptors by token bucket fill rate to ensure the descriptor with the
// smallest fill rate is consumed first.
std::sort(matched_descriptors.begin(), matched_descriptors.end(),
[](const RateLimitTokenBucket* lhs, const RateLimitTokenBucket* rhs) {
return lhs->fillRate() < rhs->fillRate();
});
std::sort(matched_results.begin(), matched_results.end(), [](const auto& lhs, const auto& rhs) {
return lhs.token_bucket.get().fillRate() < rhs.token_bucket.get().fillRate();
});
}

const double share_factor =
share_provider_ != nullptr ? share_provider_->getTokensShareFactor() : 1.0;

// See if the request is forbidden by any of the matched descriptors.
for (auto descriptor : matched_descriptors) {
if (!descriptor->consume(share_factor)) {
for (auto match_result : matched_results) {
if (!match_result.token_bucket.get().consume(
share_factor, match_result.request_descriptor.get().hits_addend_.value_or(1))) {
// If the request is forbidden by a descriptor, return the result and the descriptor
// token bucket.
return {false, makeOptRefFromPtr<TokenBucketContext>(descriptor)};
return {false, makeOptRef<TokenBucketContext>(match_result.token_bucket.get())};
}
}

// See if the request is forbidden by the default token bucket.
if (matched_descriptors.empty() || always_consume_default_token_bucket_) {
if (matched_results.empty() || always_consume_default_token_bucket_) {
if (const bool result = default_token_bucket_->consume(share_factor); !result) {
// If the request is forbidden by the default token bucket, return the result and the
// default token bucket.
Expand All @@ -315,13 +320,13 @@ LocalRateLimiterImpl::Result LocalRateLimiterImpl::requestAllowed(

// If the request is allowed then return the result the token bucket. The descriptor
// token bucket will be selected as priority if it exists.
return {true, makeOptRefFromPtr<TokenBucketContext>(matched_descriptors.empty()
? default_token_bucket_.get()
: matched_descriptors[0])};
return {true, makeOptRef<TokenBucketContext>(matched_results.empty()
? *default_token_bucket_
: matched_results[0].token_bucket.get())};
};

ASSERT(!matched_descriptors.empty());
return {true, makeOptRefFromPtr<TokenBucketContext>(matched_descriptors[0])};
ASSERT(!matched_results.empty());
return {true, makeOptRef<TokenBucketContext>(matched_results[0].token_bucket.get())};
}

} // namespace LocalRateLimit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class TokenBucketContext {

class RateLimitTokenBucket : public TokenBucketContext {
public:
virtual bool consume(double factor = 1.0) PURE;
virtual bool consume(double factor = 1.0, uint32_t tokens = 1) PURE;
virtual void onFillTimer(uint64_t refill_counter, double factor = 1.0) PURE;
virtual std::chrono::milliseconds fillInterval() const PURE;
virtual double fillRate() const PURE;
Expand All @@ -86,7 +86,7 @@ class TimerTokenBucket : public RateLimitTokenBucket {
LocalRateLimiterImpl& parent);

// RateLimitTokenBucket
bool consume(double factor) override;
bool consume(double factor = 1.0, uint32_t tokens = 1) override;
void onFillTimer(uint64_t refill_counter, double factor) override;
std::chrono::milliseconds fillInterval() const override { return fill_interval_; }
double fillRate() const override { return fill_rate_; }
Expand Down Expand Up @@ -115,7 +115,7 @@ class AtomicTokenBucket : public RateLimitTokenBucket {
std::chrono::milliseconds fill_interval, TimeSource& time_source);

// RateLimitTokenBucket
bool consume(double factor) override;
bool consume(double factor = 1.0, uint32_t tokens = 1) override;
void onFillTimer(uint64_t, double) override {}
std::chrono::milliseconds fillInterval() const override { return {}; }
double fillRate() const override { return token_bucket_.fillRate(); }
Expand Down Expand Up @@ -145,7 +145,7 @@ class LocalRateLimiterImpl {
ShareProviderSharedPtr shared_provider = nullptr);
~LocalRateLimiterImpl();

Result requestAllowed(absl::Span<const RateLimit::LocalDescriptor> request_descriptors) const;
Result requestAllowed(absl::Span<const RateLimit::Descriptor> request_descriptors) const;

private:
void onFillTimer();
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/filters/common/ratelimit_config/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ envoy_cc_library(
hdrs = ["ratelimit_config.h"],
deps = [
"//envoy/ratelimit:ratelimit_interface",
"//source/common/formatter:formatter_extension_lib",
"//source/common/formatter:substitution_formatter_lib",
"//source/common/router:router_ratelimit_lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,33 @@ namespace RateLimit {
RateLimitPolicy::RateLimitPolicy(const ProtoRateLimit& config,
Server::Configuration::CommonFactoryContext& context,
absl::Status& creation_status, bool no_limit) {
if (config.has_hits_addend()) {
if (!config.hits_addend().format().empty()) {
// Ensure only format or number is set.
if (config.hits_addend().has_number()) {
creation_status =
absl::InvalidArgumentError("hits_addend must contain either a format or a number");
return;
}

auto providers_or_error =
Formatter::SubstitutionFormatParser::parse(config.hits_addend().format());
SET_AND_RETURN_IF_NOT_OK(providers_or_error.status(), creation_status);
if (providers_or_error->size() != 1) {
creation_status =
absl::InvalidArgumentError("hits_addend format must contain exactly one substitution");
return;
}
hits_addend_provider_ = std::move(providers_or_error.value()[0]);
} else if (config.hits_addend().has_number()) {
hits_addend_ = config.hits_addend().number().value();
} else {
creation_status =
absl::InvalidArgumentError("hits_addend must contain either a format or a number");
return;
}
}

if (config.has_stage() || !config.disable_key().empty()) {
creation_status =
absl::InvalidArgumentError("'stage' field and 'disable_key' field are not supported");
Expand Down Expand Up @@ -101,7 +128,7 @@ void RateLimitPolicy::populateDescriptors(const Http::RequestHeaderMap& headers,
const StreamInfo::StreamInfo& stream_info,
const std::string& local_service_cluster,
RateLimitDescriptors& descriptors) const {
Envoy::RateLimit::LocalDescriptor descriptor;
Envoy::RateLimit::Descriptor descriptor;
for (const Envoy::RateLimit::DescriptorProducerPtr& action : actions_) {
Envoy::RateLimit::DescriptorEntry entry;
if (!action->populateDescriptor(entry, local_service_cluster, headers, stream_info)) {
Expand All @@ -111,6 +138,20 @@ void RateLimitPolicy::populateDescriptors(const Http::RequestHeaderMap& headers,
descriptor.entries_.emplace_back(std::move(entry));
}
}

// Populate hits_addend if set.
if (hits_addend_provider_ != nullptr) {
const ProtobufWkt::Value hits_addend =
hits_addend_provider_->formatValueWithContext({&headers}, stream_info);
if (hits_addend.has_number_value()) {
descriptor.hits_addend_ = static_cast<uint32_t>(hits_addend.number_value());
} else {
ENVOY_LOG(warn, "hits_addend must be a number");
}
} else if (hits_addend_.has_value()) {
descriptor.hits_addend_ = hits_addend_.value();
}

descriptors.emplace_back(std::move(descriptor));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "envoy/config/route/v3/route_components.pb.h"
#include "envoy/ratelimit/ratelimit.h"

#include "source/common/formatter/substitution_formatter.h"
#include "source/common/router/router_ratelimit.h"

#include "absl/container/inlined_vector.h"
Expand All @@ -14,7 +15,7 @@ namespace Common {
namespace RateLimit {

using ProtoRateLimit = envoy::config::route::v3::RateLimit;
using RateLimitDescriptors = std::vector<Envoy::RateLimit::LocalDescriptor>;
using RateLimitDescriptors = std::vector<Envoy::RateLimit::Descriptor>;

class RateLimitPolicy : Logger::Loggable<Envoy::Logger::Id::config> {
public:
Expand All @@ -28,6 +29,8 @@ class RateLimitPolicy : Logger::Loggable<Envoy::Logger::Id::config> {
RateLimitDescriptors& descriptors) const;

private:
Formatter::FormatterProviderPtr hits_addend_provider_;
absl::optional<uint32_t> hits_addend_;
std::vector<Envoy::RateLimit::DescriptorProducerPtr> actions_;
};

Expand Down
Loading

0 comments on commit 6944a80

Please sign in to comment.