Skip to content

Commit

Permalink
feat: synchronous origin call on cache miss
Browse files Browse the repository at this point in the history
  • Loading branch information
chronark committed Aug 8, 2024
1 parent 86a6db9 commit d30965d
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 31 deletions.
22 changes: 17 additions & 5 deletions apps/agent/pkg/ratelimit/fixed_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,28 @@ func (r *fixedWindow) removeExpiredIdentifiers() {
}
}

func buildKey(identifier string, limit int64, duration time.Duration) string {
func buildKey(identifier string, duration time.Duration) string {
window := time.Now().UnixMilli() / duration.Milliseconds()
return fmt.Sprintf("ratelimit:%s:%d:%d", identifier, limit, window)
return fmt.Sprintf("ratelimit:%s:%d", identifier, window)
}

// Has returns true if there is already a record for the given identifier in the current window
func (r *fixedWindow) Has(ctx context.Context, identifier string, duration time.Duration) bool {
ctx, span := tracing.Start(ctx, "fixedWindow.Has")
defer span.End()
key := buildKey(identifier, duration)

r.identifiersLock.RLock(ctx)
_, ok := r.identifiers[key]
r.identifiersLock.RUnlock(ctx)
return ok
}

func (r *fixedWindow) Take(ctx context.Context, req RatelimitRequest) RatelimitResponse {
ctx, span := tracing.Start(ctx, tracing.NewSpanName("fixedWindow.Take", req.Name))
ctx, span := tracing.Start(ctx, "fixedWindow.Take")
defer span.End()

key := buildKey(req.Identifier, req.Limit, req.Duration)
key := buildKey(req.Identifier, req.Duration)
span.SetAttributes(attribute.String("key", key))

r.identifiersLock.RLock(ctx)
Expand Down Expand Up @@ -147,7 +159,7 @@ func (r *fixedWindow) Take(ctx context.Context, req RatelimitRequest) RatelimitR
func (r *fixedWindow) SetCurrent(ctx context.Context, req SetCurrentRequest) error {
ctx, span := tracing.Start(ctx, "fixedWindow.SetCurrent")
defer span.End()
key := buildKey(req.Identifier, req.Limit, req.Duration)
key := buildKey(req.Identifier, req.Duration)

r.identifiersLock.RLock(ctx)
id, ok := r.identifiers[req.Identifier]
Expand Down
1 change: 1 addition & 0 deletions apps/agent/pkg/ratelimit/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

type Ratelimiter interface {
Take(ctx context.Context, req RatelimitRequest) RatelimitResponse
Has(ctx context.Context, identifier string, duration time.Duration) bool
SetCurrent(ctx context.Context, req SetCurrentRequest) error
CommitLease(ctx context.Context, req CommitLeaseRequest) error
}
Expand Down
2 changes: 1 addition & 1 deletion apps/agent/services/ratelimit/commit_lease.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (s *service) CommitLease(ctx context.Context, req *ratelimitv1.CommitLeaseR
ctx, span := tracing.Start(ctx, "svc.ratelimit.CommitLease")
defer span.End()

key := ratelimitNodeKey(req.Lease.Identifier, req.Lease.Limit, req.Lease.Duration)
key := ratelimitNodeKey(req.Lease.Identifier, req.Lease.Duration)

origin, err := s.cluster.FindNode(key)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions apps/agent/services/ratelimit/flush_push_pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1"
)

func ratelimitNodeKey(identifier string, limit int64, duration int64) string {
func ratelimitNodeKey(identifier string, duration int64) string {
window := time.Now().UnixMilli() / duration
return fmt.Sprintf("ratelimit:%s:%d:%d", identifier, window, limit)
return fmt.Sprintf("ratelimit:%s:%d", identifier, window)
}

func (s *service) aggregateByOrigin(ctx context.Context, events []*ratelimitv1.PushPullEvent) {
Expand All @@ -21,7 +21,7 @@ func (s *service) aggregateByOrigin(ctx context.Context, events []*ratelimitv1.P

eventsByKey := map[string][]*ratelimitv1.PushPullEvent{}
for _, e := range events {
key := ratelimitNodeKey(e.Identifier, e.Limit, e.Duration)
key := ratelimitNodeKey(e.Identifier, e.Duration)
_, ok := eventsByKey[key]
if !ok {
eventsByKey[key] = []*ratelimitv1.PushPullEvent{}
Expand Down
82 changes: 82 additions & 0 deletions apps/agent/services/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@ package ratelimit

import (
"context"
"fmt"
"net/http"
"strings"
"time"

"connectrpc.com/connect"
"connectrpc.com/otelconnect"
ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1"
"github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1/ratelimitv1connect"
"github.com/unkeyed/unkey/apps/agent/pkg/ratelimit"
"github.com/unkeyed/unkey/apps/agent/pkg/tracing"
"go.opentelemetry.io/otel/attribute"
)

func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitRequest) (*ratelimitv1.RatelimitResponse, error) {

ctx, span := tracing.Start(ctx, "ratelimit.Ratelimit")
defer span.End()

ratelimitReq := ratelimit.RatelimitRequest{
Name: req.Name,
Identifier: req.Identifier,
Expand All @@ -25,6 +34,17 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque
ExpiresAt: time.Now().Add(time.Duration(req.Lease.Timeout) * time.Millisecond),
}
}

if !s.ratelimiter.Has(ctx, ratelimitReq.Identifier, ratelimitReq.Duration) {
originRes, err := s.ratelimitOrigin(ctx, req)
if err != nil {
s.logger.Err(err).Msg("failed to call ratelimit origin")
}
if originRes != nil {
return originRes, nil
}
}

taken := s.ratelimiter.Take(ctx, ratelimitReq)

if s.batcher != nil {
Expand Down Expand Up @@ -63,3 +83,65 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque
return res, nil

}

func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.RatelimitRequest) (*ratelimitv1.RatelimitResponse, error) {
ctx, span := tracing.Start(ctx, "ratelimit.RatelimitOrigin")
defer span.End()

s.logger.Info().Str("identifier", req.Identifier).Msg("no local state found, syncing with origin")
key := ratelimitNodeKey(req.Identifier, req.Duration)
peer, err := s.cluster.FindNode(key)
if err != nil {
tracing.RecordError(span, err)
s.logger.Warn().Err(err).Str("key", key).Msg("unable to find responsible nodes")
return nil, err
}

if peer.Id == s.cluster.NodeId() {
return nil, nil
}

s.consistencyChecker.Record(key, peer.Id)

url := peer.RpcAddr
if !strings.Contains(url, "://") {
url = "http://" + url
}

s.peersMu.RLock()
c, ok := s.peers[url]
s.peersMu.RUnlock()
if !ok {
interceptor, err := otelconnect.NewInterceptor(otelconnect.WithTracerProvider(tracing.GetGlobalTraceProvider()))
if err != nil {
tracing.RecordError(span, err)
s.logger.Err(err).Msg("failed to create interceptor")
return nil, err
}
c = ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, url, connect.WithInterceptors(interceptor))
s.peersMu.Lock()
s.peers[url] = c
s.peersMu.Unlock()
}

connectReq := connect.NewRequest(req)

connectReq.Header().Set("Authorization", fmt.Sprintf("Bearer %s", s.cluster.AuthToken()))

res, err := c.Ratelimit(ctx, connectReq)
if err != nil {
tracing.RecordError(span, err)
s.logger.Err(err).Msg("failed to call ratelimit")
return nil, err
}

s.ratelimiter.SetCurrent(ctx, ratelimit.SetCurrentRequest{
Identifier: req.Identifier,
Limit: req.Limit,
Duration: time.Duration(req.Duration) * time.Millisecond,
Current: res.Msg.Current,
})

return res.Msg, nil

}
17 changes: 5 additions & 12 deletions apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ const testCases: {
duration: 10000,
rps: 15,
seconds: 120,
expected: { min: 120, max: 600 },
expected: { min: 120, max: 150 },
},
{
name: "High Rate with Short Window",
limit: 20,
duration: 1000,
rps: 50,
seconds: 60,
expected: { min: 1200, max: 3000 },
expected: { min: 1200, max: 1500 },
},
{
name: "Constant Rate Equals Limit",
Expand All @@ -49,24 +49,16 @@ const testCases: {
duration: 10000,
rps: 100,
seconds: 30,
expected: { min: 1500, max: 3000 },
expected: { min: 1500, max: 2000 },
},
{
name: "Rate Higher Than Limit",
limit: 100,
duration: 5000,
rps: 200,
seconds: 120,
expected: { min: 2400, max: 6000 },
expected: { min: 2400, max: 3000 },
},
// {
// name: "Long Window",
// limit: 100,
// duration: 60000,
// rps: 3,
// seconds: 120,
// expected: { min: 200, max: 400 },
// },
];

for (const { name, limit, duration, rps, seconds, expected } of testCases) {
Expand Down Expand Up @@ -112,6 +104,7 @@ for (const { name, limit, duration, rps, seconds, expected } of testCases) {
const passed = results.reduce((sum, res) => {
return res.body.success ? sum + 1 : sum;
}, 0);
console.info({ name, passed });
t.expect(passed).toBeGreaterThanOrEqual(expected.min);
t.expect(passed).toBeLessThanOrEqual(expected.max);
},
Expand Down
17 changes: 7 additions & 10 deletions tools/artillery/keys.verifyKey.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
config:
target: https://api.unkey.dev
phases:
- name: Ramp up
duration: 1m
arrivalRate: 1
rampTo: 10
- name: Sustain
duration: 5m
arrivalRate: 10
duration: 10m
arrivalRate: 100
payload:
path: './.keys.csv'
fields:
Expand All @@ -34,11 +30,12 @@ scenarios:
capture:
- json: "$.valid"
as: valid
- json: "$.code"
as: code
expect:
- statusCode: 200
- contentType: json
- hasProperty: valid
- hasProperty:
- equals:
- "true"
- "{{ valid }}"

- "VALID"
- "{{ code }}"

0 comments on commit d30965d

Please sign in to comment.