Skip to content

Commit

Permalink
refactor EndpointsRateLimitInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffy-mathew committed Aug 20, 2024
1 parent c0df1bf commit 7c03309
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 147 deletions.
2 changes: 1 addition & 1 deletion gateway/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type DBAccessDefinition struct {
Limit *user.APILimit `json:"limit"`

// Endpoints contains endpoint rate limit settings.
Endpoints []user.Endpoint `json:"endpoints,omitempty"`
Endpoints user.Endpoints `json:"endpoints,omitempty"`
}

func (d *DBAccessDefinition) ToRegularAD() user.AccessDefinition {
Expand Down
41 changes: 4 additions & 37 deletions gateway/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/TykTechnologies/tyk/regexp"

"github.com/TykTechnologies/drl"
"github.com/TykTechnologies/leakybucket"
"github.com/TykTechnologies/leakybucket/memorycache"
Expand Down Expand Up @@ -243,11 +240,11 @@ func (l *SessionLimiter) ForwardMessage(r *http.Request, session *user.SessionSt
)

if len(accessDef.Endpoints) > 0 {
endpointRLInfo, doEndpointRL := getEndpointRateLimitInfo(r.Method, r.URL.Path, accessDef.Endpoints)
endpointRLInfo, doEndpointRL := accessDef.Endpoints.RateLimitInfo(r.Method, r.URL.Path)
if doEndpointRL {
apiLimit.Rate = endpointRLInfo.rate
apiLimit.Per = endpointRLInfo.per
endpointRLKeySuffix = endpointRLInfo.keySuffix
apiLimit.Rate = endpointRLInfo.Rate
apiLimit.Per = endpointRLInfo.Per
endpointRLKeySuffix = endpointRLInfo.KeySuffix
}
}

Expand Down Expand Up @@ -442,33 +439,3 @@ func GetAccessDefinitionByAPIIDOrSession(session *user.SessionState, api *APISpe

return accessDef, allowanceScope, nil
}

type endpointRateLimitInfo struct {
keySuffix string
rate float64
per float64
}

func getEndpointRateLimitInfo(method string, path string, endpoints []user.Endpoint) (*endpointRateLimitInfo, bool) {
for _, endpoint := range endpoints {
asRegex, err := regexp.Compile(endpoint.Path)
if err != nil {
return nil, false
}

match := asRegex.MatchString(path)
if match {
for _, endpointMethod := range endpoint.Methods {
if strings.ToUpper(endpointMethod.Name) == strings.ToUpper(method) {
return &endpointRateLimitInfo{
keySuffix: storage.HashStr(fmt.Sprintf("%s:%s", endpointMethod.Name, endpoint.Path)),
rate: endpointMethod.Limit.Rate,
per: endpointMethod.Limit.Per,
}, true
}
}
}
}

return nil, false
}
106 changes: 0 additions & 106 deletions gateway/session_manager_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package gateway

import (
"net/http"
"testing"

"github.com/TykTechnologies/tyk/storage"

"github.com/stretchr/testify/assert"

"github.com/TykTechnologies/tyk/apidef"
Expand Down Expand Up @@ -143,106 +140,3 @@ func TestGetAccessDefinitionByAPIIDOrSession(t *testing.T) {
assert.NoError(t, err)
})
}

func TestGetEndpointRateLimitInfo(t *testing.T) {
tests := []struct {
name string
method string
path string
endpoints []user.Endpoint
expected *endpointRateLimitInfo
found bool
}{
{
name: "Matching endpoint and method",
method: http.MethodGet,
path: "/api/v1/users",
endpoints: []user.Endpoint{
{
Path: "/api/v1/users",
Methods: []user.EndpointMethod{
{Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: &endpointRateLimitInfo{
keySuffix: storage.HashStr("GET:/api/v1/users"),
rate: 100,
per: 60,
},
found: true,
},
{
name: "Matching endpoint, non-matching method",
path: "/api/v1/users",
method: http.MethodPost,
endpoints: []user.Endpoint{
{
Path: "/api/v1/users",
Methods: []user.EndpointMethod{
{Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
{
name: "Non-matching endpoint",
method: http.MethodGet,
path: "/api/v1/products",
endpoints: []user.Endpoint{
{
Path: "/api/v1/users",
Methods: []user.EndpointMethod{
{Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
{
name: "Regex path matching",
path: "/api/v1/users/123",
method: http.MethodGet,
endpoints: []user.Endpoint{
{
Path: "/api/v1/users/[0-9]+",
Methods: []user.EndpointMethod{
{Name: "GET", Limit: user.RateLimit{Rate: 50, Per: 30}},
},
},
},
expected: &endpointRateLimitInfo{
keySuffix: storage.HashStr("GET:/api/v1/users/[0-9]+"),
rate: 50,
per: 30,
},
found: true,
},
{
name: "Invalid regex path",
path: "/api/v1/users",
method: http.MethodGet,
endpoints: []user.Endpoint{
{
Path: "[invalid regex",
Methods: []user.EndpointMethod{
{Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, found := getEndpointRateLimitInfo(tt.method, tt.path, tt.endpoints)
assert.Equal(t, tt.found, found)
assert.Equal(t, tt.expected, result)
})
}
}
53 changes: 50 additions & 3 deletions user/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package user
import (
"crypto/md5"
"fmt"
"strings"
"time"

"github.com/TykTechnologies/tyk/regexp"
"github.com/TykTechnologies/tyk/storage"

"github.com/TykTechnologies/graphql-go-tools/pkg/graphql"

"github.com/TykTechnologies/tyk/apidef"
Expand Down Expand Up @@ -110,7 +114,7 @@ type AccessDefinition struct {

AllowanceScope string `json:"allowance_scope" msg:"allowance_scope"`

Endpoints []Endpoint `json:"endpoints,omitempty" msg:"endpoints,omitempty"`
Endpoints Endpoints `json:"endpoints,omitempty" msg:"endpoints,omitempty"`
}

// IsEmpty checks if APILimit is empty.
Expand Down Expand Up @@ -181,11 +185,19 @@ type Monitor struct {
TriggerLimits []float64 `json:"trigger_limits" msg:"trigger_limits"`
}

// Endpoints is a collection of Endpoint.
type Endpoints []Endpoint

// Endpoint holds the configuration for endpoint rate limiting.
type Endpoint struct {
Path string `json:"path,omitempty" msg:"path"`
Methods []EndpointMethod `json:"methods,omitempty" msg:"methods"`
Path string `json:"path,omitempty" msg:"path"`
Methods EndpointMethods `json:"methods,omitempty" msg:"methods"`
}

// EndpointMethods is a collection of EndpointMethod.
type EndpointMethods []EndpointMethod

// EndpointMethod holds the configuration on endpoint method level.
type EndpointMethod struct {
Name string `json:"name,omitempty" msg:"name,omitempty"`
Limit RateLimit `json:"limit,omitempty" msg:"limit,omitempty"`
Expand Down Expand Up @@ -464,3 +476,38 @@ func (s *SessionState) GetQuotaLimitByAPIID(apiID string) (int64, int64, int64,
func (s *SessionState) IsBasicAuth() bool {
return s.BasicAuthData.Password != ""
}

// EndpointRateLimitInfo holds the information to process endpoint rate limits.
type EndpointRateLimitInfo struct {
// KeySuffix is the suffix to use for the storage key.
KeySuffix string
// Rate is the allowance.
Rate float64
// Per is the rate limiting interval.
Per float64
}

// RateLimitInfo returns EndpointRateLimitInfo for endpoint rate limiting.
func (es Endpoints) RateLimitInfo(method string, path string) (*EndpointRateLimitInfo, bool) {
for _, endpoint := range es {
asRegex, err := regexp.Compile(endpoint.Path)
if err != nil {
return nil, false
}

match := asRegex.MatchString(path)
if match {
for _, endpointMethod := range endpoint.Methods {
if strings.ToUpper(endpointMethod.Name) == strings.ToUpper(method) {
return &EndpointRateLimitInfo{
KeySuffix: storage.HashStr(fmt.Sprintf("%s:%s", endpointMethod.Name, endpoint.Path)),
Rate: endpointMethod.Limit.Rate,
Per: endpointMethod.Limit.Per,
}, true
}
}
}
}

return nil, false
}
106 changes: 106 additions & 0 deletions user/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package user

import (
"encoding/json"
"net/http"
"reflect"
"testing"
"time"

"github.com/TykTechnologies/tyk/storage"

"github.com/TykTechnologies/tyk/apidef"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -373,3 +376,106 @@ func TestAPILimit_Clone(t *testing.T) {
})
}
}

func TestEndpoints_RateLimitInfo(t *testing.T) {
tests := []struct {
name string
method string
path string
endpoints Endpoints
expected *EndpointRateLimitInfo
found bool
}{
{
name: "Matching endpoint and method",
method: http.MethodGet,
path: "/api/v1/users",
endpoints: Endpoints{
{
Path: "/api/v1/users",
Methods: []EndpointMethod{
{Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: &EndpointRateLimitInfo{
KeySuffix: storage.HashStr("GET:/api/v1/users"),
Rate: 100,
Per: 60,
},
found: true,
},
{
name: "Matching endpoint, non-matching method",
path: "/api/v1/users",
method: http.MethodPost,
endpoints: []Endpoint{
{
Path: "/api/v1/users",
Methods: []EndpointMethod{
{Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
{
name: "Non-matching endpoint",
method: http.MethodGet,
path: "/api/v1/products",
endpoints: []Endpoint{
{
Path: "/api/v1/users",
Methods: []EndpointMethod{
{Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
{
name: "Regex path matching",
path: "/api/v1/users/123",
method: http.MethodGet,
endpoints: []Endpoint{
{
Path: "/api/v1/users/[0-9]+",
Methods: []EndpointMethod{
{Name: "GET", Limit: RateLimit{Rate: 50, Per: 30}},
},
},
},
expected: &EndpointRateLimitInfo{
KeySuffix: storage.HashStr("GET:/api/v1/users/[0-9]+"),
Rate: 50,
Per: 30,
},
found: true,
},
{
name: "Invalid regex path",
path: "/api/v1/users",
method: http.MethodGet,
endpoints: []Endpoint{
{
Path: "[invalid regex",
Methods: []EndpointMethod{
{Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}},
},
},
},
expected: nil,
found: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, found := tt.endpoints.RateLimitInfo(tt.method, tt.path)
assert.Equal(t, tt.found, found)
assert.Equal(t, tt.expected, result)
})
}
}

0 comments on commit 7c03309

Please sign in to comment.