From d30965dd862fe067b14d0ee419f5bc470706e431 Mon Sep 17 00:00:00 2001 From: chronark Date: Thu, 8 Aug 2024 16:43:34 +0200 Subject: [PATCH] feat: synchronous origin call on cache miss --- apps/agent/pkg/ratelimit/fixed_window.go | 22 +++-- apps/agent/pkg/ratelimit/interface.go | 1 + apps/agent/services/ratelimit/commit_lease.go | 2 +- .../services/ratelimit/flush_push_pull.go | 6 +- apps/agent/services/ratelimit/ratelimit.go | 82 +++++++++++++++++++ .../v1_ratelimit_limit.accuracy.test.ts | 17 ++-- tools/artillery/keys.verifyKey.yaml | 17 ++-- 7 files changed, 116 insertions(+), 31 deletions(-) diff --git a/apps/agent/pkg/ratelimit/fixed_window.go b/apps/agent/pkg/ratelimit/fixed_window.go index c5c88e6eea..e53a28b174 100644 --- a/apps/agent/pkg/ratelimit/fixed_window.go +++ b/apps/agent/pkg/ratelimit/fixed_window.go @@ -82,16 +82,28 @@ func (r *fixedWindow) removeExpiredIdentifiers() { } } -func buildKey(identifier string, limit int64, duration time.Duration) string { +func buildKey(identifier string, duration time.Duration) string { window := time.Now().UnixMilli() / duration.Milliseconds() - return fmt.Sprintf("ratelimit:%s:%d:%d", identifier, limit, window) + return fmt.Sprintf("ratelimit:%s:%d", identifier, window) +} + +// Has returns true if there is already a record for the given identifier in the current window +func (r *fixedWindow) Has(ctx context.Context, identifier string, duration time.Duration) bool { + ctx, span := tracing.Start(ctx, "fixedWindow.Has") + defer span.End() + key := buildKey(identifier, duration) + + r.identifiersLock.RLock(ctx) + _, ok := r.identifiers[key] + r.identifiersLock.RUnlock(ctx) + return ok } func (r *fixedWindow) Take(ctx context.Context, req RatelimitRequest) RatelimitResponse { - ctx, span := tracing.Start(ctx, tracing.NewSpanName("fixedWindow.Take", req.Name)) + ctx, span := tracing.Start(ctx, "fixedWindow.Take") defer span.End() - key := buildKey(req.Identifier, req.Limit, req.Duration) + key := buildKey(req.Identifier, req.Duration) span.SetAttributes(attribute.String("key", key)) r.identifiersLock.RLock(ctx) @@ -147,7 +159,7 @@ func (r *fixedWindow) Take(ctx context.Context, req RatelimitRequest) RatelimitR func (r *fixedWindow) SetCurrent(ctx context.Context, req SetCurrentRequest) error { ctx, span := tracing.Start(ctx, "fixedWindow.SetCurrent") defer span.End() - key := buildKey(req.Identifier, req.Limit, req.Duration) + key := buildKey(req.Identifier, req.Duration) r.identifiersLock.RLock(ctx) id, ok := r.identifiers[req.Identifier] diff --git a/apps/agent/pkg/ratelimit/interface.go b/apps/agent/pkg/ratelimit/interface.go index b1a03a958a..8c73a8e072 100644 --- a/apps/agent/pkg/ratelimit/interface.go +++ b/apps/agent/pkg/ratelimit/interface.go @@ -7,6 +7,7 @@ import ( type Ratelimiter interface { Take(ctx context.Context, req RatelimitRequest) RatelimitResponse + Has(ctx context.Context, identifier string, duration time.Duration) bool SetCurrent(ctx context.Context, req SetCurrentRequest) error CommitLease(ctx context.Context, req CommitLeaseRequest) error } diff --git a/apps/agent/services/ratelimit/commit_lease.go b/apps/agent/services/ratelimit/commit_lease.go index 2a52412900..905d4aef2c 100644 --- a/apps/agent/services/ratelimit/commit_lease.go +++ b/apps/agent/services/ratelimit/commit_lease.go @@ -20,7 +20,7 @@ func (s *service) CommitLease(ctx context.Context, req *ratelimitv1.CommitLeaseR ctx, span := tracing.Start(ctx, "svc.ratelimit.CommitLease") defer span.End() - key := ratelimitNodeKey(req.Lease.Identifier, req.Lease.Limit, req.Lease.Duration) + key := ratelimitNodeKey(req.Lease.Identifier, req.Lease.Duration) origin, err := s.cluster.FindNode(key) if err != nil { diff --git a/apps/agent/services/ratelimit/flush_push_pull.go b/apps/agent/services/ratelimit/flush_push_pull.go index b0af6f6d85..afc9a02d8b 100644 --- a/apps/agent/services/ratelimit/flush_push_pull.go +++ b/apps/agent/services/ratelimit/flush_push_pull.go @@ -8,9 +8,9 @@ import ( ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1" ) -func ratelimitNodeKey(identifier string, limit int64, duration int64) string { +func ratelimitNodeKey(identifier string, duration int64) string { window := time.Now().UnixMilli() / duration - return fmt.Sprintf("ratelimit:%s:%d:%d", identifier, window, limit) + return fmt.Sprintf("ratelimit:%s:%d", identifier, window) } func (s *service) aggregateByOrigin(ctx context.Context, events []*ratelimitv1.PushPullEvent) { @@ -21,7 +21,7 @@ func (s *service) aggregateByOrigin(ctx context.Context, events []*ratelimitv1.P eventsByKey := map[string][]*ratelimitv1.PushPullEvent{} for _, e := range events { - key := ratelimitNodeKey(e.Identifier, e.Limit, e.Duration) + key := ratelimitNodeKey(e.Identifier, e.Duration) _, ok := eventsByKey[key] if !ok { eventsByKey[key] = []*ratelimitv1.PushPullEvent{} diff --git a/apps/agent/services/ratelimit/ratelimit.go b/apps/agent/services/ratelimit/ratelimit.go index c92299a618..097860b5bf 100644 --- a/apps/agent/services/ratelimit/ratelimit.go +++ b/apps/agent/services/ratelimit/ratelimit.go @@ -2,9 +2,15 @@ package ratelimit import ( "context" + "fmt" + "net/http" + "strings" "time" + "connectrpc.com/connect" + "connectrpc.com/otelconnect" ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1" + "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1/ratelimitv1connect" "github.com/unkeyed/unkey/apps/agent/pkg/ratelimit" "github.com/unkeyed/unkey/apps/agent/pkg/tracing" "go.opentelemetry.io/otel/attribute" @@ -12,6 +18,9 @@ import ( func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitRequest) (*ratelimitv1.RatelimitResponse, error) { + ctx, span := tracing.Start(ctx, "ratelimit.Ratelimit") + defer span.End() + ratelimitReq := ratelimit.RatelimitRequest{ Name: req.Name, Identifier: req.Identifier, @@ -25,6 +34,17 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque ExpiresAt: time.Now().Add(time.Duration(req.Lease.Timeout) * time.Millisecond), } } + + if !s.ratelimiter.Has(ctx, ratelimitReq.Identifier, ratelimitReq.Duration) { + originRes, err := s.ratelimitOrigin(ctx, req) + if err != nil { + s.logger.Err(err).Msg("failed to call ratelimit origin") + } + if originRes != nil { + return originRes, nil + } + } + taken := s.ratelimiter.Take(ctx, ratelimitReq) if s.batcher != nil { @@ -63,3 +83,65 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque return res, nil } + +func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.RatelimitRequest) (*ratelimitv1.RatelimitResponse, error) { + ctx, span := tracing.Start(ctx, "ratelimit.RatelimitOrigin") + defer span.End() + + s.logger.Info().Str("identifier", req.Identifier).Msg("no local state found, syncing with origin") + key := ratelimitNodeKey(req.Identifier, req.Duration) + peer, err := s.cluster.FindNode(key) + if err != nil { + tracing.RecordError(span, err) + s.logger.Warn().Err(err).Str("key", key).Msg("unable to find responsible nodes") + return nil, err + } + + if peer.Id == s.cluster.NodeId() { + return nil, nil + } + + s.consistencyChecker.Record(key, peer.Id) + + url := peer.RpcAddr + if !strings.Contains(url, "://") { + url = "http://" + url + } + + s.peersMu.RLock() + c, ok := s.peers[url] + s.peersMu.RUnlock() + if !ok { + interceptor, err := otelconnect.NewInterceptor(otelconnect.WithTracerProvider(tracing.GetGlobalTraceProvider())) + if err != nil { + tracing.RecordError(span, err) + s.logger.Err(err).Msg("failed to create interceptor") + return nil, err + } + c = ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, url, connect.WithInterceptors(interceptor)) + s.peersMu.Lock() + s.peers[url] = c + s.peersMu.Unlock() + } + + connectReq := connect.NewRequest(req) + + connectReq.Header().Set("Authorization", fmt.Sprintf("Bearer %s", s.cluster.AuthToken())) + + res, err := c.Ratelimit(ctx, connectReq) + if err != nil { + tracing.RecordError(span, err) + s.logger.Err(err).Msg("failed to call ratelimit") + return nil, err + } + + s.ratelimiter.SetCurrent(ctx, ratelimit.SetCurrentRequest{ + Identifier: req.Identifier, + Limit: req.Limit, + Duration: time.Duration(req.Duration) * time.Millisecond, + Current: res.Msg.Current, + }) + + return res.Msg, nil + +} diff --git a/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts b/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts index d4033c7e56..d33e738345 100644 --- a/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts +++ b/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts @@ -25,7 +25,7 @@ const testCases: { duration: 10000, rps: 15, seconds: 120, - expected: { min: 120, max: 600 }, + expected: { min: 120, max: 150 }, }, { name: "High Rate with Short Window", @@ -33,7 +33,7 @@ const testCases: { duration: 1000, rps: 50, seconds: 60, - expected: { min: 1200, max: 3000 }, + expected: { min: 1200, max: 1500 }, }, { name: "Constant Rate Equals Limit", @@ -49,7 +49,7 @@ const testCases: { duration: 10000, rps: 100, seconds: 30, - expected: { min: 1500, max: 3000 }, + expected: { min: 1500, max: 2000 }, }, { name: "Rate Higher Than Limit", @@ -57,16 +57,8 @@ const testCases: { duration: 5000, rps: 200, seconds: 120, - expected: { min: 2400, max: 6000 }, + expected: { min: 2400, max: 3000 }, }, - // { - // name: "Long Window", - // limit: 100, - // duration: 60000, - // rps: 3, - // seconds: 120, - // expected: { min: 200, max: 400 }, - // }, ]; for (const { name, limit, duration, rps, seconds, expected } of testCases) { @@ -112,6 +104,7 @@ for (const { name, limit, duration, rps, seconds, expected } of testCases) { const passed = results.reduce((sum, res) => { return res.body.success ? sum + 1 : sum; }, 0); + console.info({ name, passed }); t.expect(passed).toBeGreaterThanOrEqual(expected.min); t.expect(passed).toBeLessThanOrEqual(expected.max); }, diff --git a/tools/artillery/keys.verifyKey.yaml b/tools/artillery/keys.verifyKey.yaml index 7f94661f38..d2c277ba57 100644 --- a/tools/artillery/keys.verifyKey.yaml +++ b/tools/artillery/keys.verifyKey.yaml @@ -1,13 +1,9 @@ config: target: https://api.unkey.dev phases: - - name: Ramp up - duration: 1m - arrivalRate: 1 - rampTo: 10 - name: Sustain - duration: 5m - arrivalRate: 10 + duration: 10m + arrivalRate: 100 payload: path: './.keys.csv' fields: @@ -34,11 +30,12 @@ scenarios: capture: - json: "$.valid" as: valid + - json: "$.code" + as: code expect: - statusCode: 200 - contentType: json - - hasProperty: valid + - hasProperty: - equals: - - "true" - - "{{ valid }}" - \ No newline at end of file + - "VALID" + - "{{ code }}"