Skip to content

Commit

Permalink
Implement ratelimiting
Browse files Browse the repository at this point in the history
  • Loading branch information
mike76-dev committed May 7, 2024
1 parent 7655156 commit 4f0ea63
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ldflags= \
-X "github.com/mike76-dev/hostscore/internal/build.NodeBinaryName=hsd" \
-X "github.com/mike76-dev/hostscore/internal/build.NodeVersion=1.1.2" \
-X "github.com/mike76-dev/hostscore/internal/build.ClientBinaryName=hsc" \
-X "github.com/mike76-dev/hostscore/internal/build.ClientVersion=1.4.0" \
-X "github.com/mike76-dev/hostscore/internal/build.ClientVersion=1.4.1" \
-X "github.com/mike76-dev/hostscore/internal/build.GitRevision=${GIT_DIRTY}${GIT_REVISION}" \
-X "github.com/mike76-dev/hostscore/internal/build.BuildTime=${BUILD_TIME}"

Expand Down
45 changes: 44 additions & 1 deletion cmd/hsc/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ type portalAPI struct {
stopChan chan struct{}
averages map[string]map[string]networkAverages
nodes map[string]nodeStatus
rl *ratelimiter
}

func newAPI(s *jsonStore, db *sql.DB, token string, logger *zap.Logger, cache *responseCache) (*portalAPI, error) {
Expand All @@ -210,6 +211,8 @@ func newAPI(s *jsonStore, db *sql.DB, token string, logger *zap.Logger, cache *r
api.hosts["mainnet"] = make(map[types.PublicKey]*portalHost)
api.hosts["zen"] = make(map[types.PublicKey]*portalHost)

api.rl = newRatelimiter(api.stopChan)

err := api.load()
if err != nil {
return nil, err
Expand Down Expand Up @@ -370,6 +373,10 @@ func (api *portalAPI) buildHTTPRoutes() {
}

func (api *portalAPI) hostsHostHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down Expand Up @@ -406,6 +413,10 @@ func (api *portalAPI) hostsHostHandler(w http.ResponseWriter, req *http.Request,
}

func (api *portalAPI) hostsHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down Expand Up @@ -513,6 +524,10 @@ func (api *portalAPI) hostsHandler(w http.ResponseWriter, req *http.Request, _ h
}

func (api *portalAPI) hostsKeysHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
err := req.ParseForm()
if err != nil {
writeError(w, "unable to parse request", http.StatusBadRequest)
Expand Down Expand Up @@ -678,6 +693,10 @@ func (api *portalAPI) hostsKeysHandler(w http.ResponseWriter, req *http.Request,
}

func (api *portalAPI) hostsScansHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down Expand Up @@ -752,6 +771,10 @@ func (api *portalAPI) hostsScansHandler(w http.ResponseWriter, req *http.Request
}

func (api *portalAPI) hostsBenchmarksHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down Expand Up @@ -835,14 +858,22 @@ func balanceStatus(balance types.Currency) string {
return "ok"
}

func (api *portalAPI) serviceStatusHandler(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
func (api *portalAPI) serviceStatusHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
writeJSON(w, statusResponse{
Version: build.ClientVersion,
Nodes: api.nodes,
})
}

func (api *portalAPI) networkHostsHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand All @@ -864,6 +895,10 @@ func (api *portalAPI) networkHostsHandler(w http.ResponseWriter, req *http.Reque
}

func (api *portalAPI) hostsChangesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down Expand Up @@ -923,6 +958,10 @@ func (api *portalAPI) hostsChangesHandler(w http.ResponseWriter, req *http.Reque
}

func (api *portalAPI) networkAveragesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand All @@ -935,6 +974,10 @@ func (api *portalAPI) networkAveragesHandler(w http.ResponseWriter, req *http.Re
}

func (api *portalAPI) networkCountriesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
if api.rl.limitExceeded(getRemoteHost(req)) {
writeError(w, "too many requests", http.StatusTooManyRequests)
return
}
network := strings.ToLower(req.FormValue("network"))
if network == "" {
network = "mainnet"
Expand Down
60 changes: 60 additions & 0 deletions cmd/hsc/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package main

import (
"net"
"net/http"
"sync"
"time"
)

const maxRequestsPerSecond = 10

// ratelimiter keeps the API request stats and determines whether
// to allow the request or not.
type ratelimiter struct {
requests map[string]int
mu sync.Mutex
}

func newRatelimiter(stopChan chan struct{}) *ratelimiter {
rl := &ratelimiter{
requests: make(map[string]int),
}

ticker := time.Tick(time.Second)
go func() {
for range ticker {
select {
case <-stopChan:
return
default:
}
rl.mu.Lock()
rl.requests = make(map[string]int)
rl.mu.Unlock()
}
}()

return rl
}

// limitExceeded returns true if there are too many requests from the given host.
func (rl *ratelimiter) limitExceeded(addr string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

rl.requests[addr]++
return rl.requests[addr] > maxRequestsPerSecond
}

// getRemoteHost returns the address of the remote host.
func getRemoteHost(r *http.Request) (host string) {
host, _, _ = net.SplitHostPort(r.RemoteAddr)
if host == "127.0.0.1" || host == "localhost" {
xff := r.Header.Values("X-Forwarded-For")
if len(xff) > 0 {
host = xff[0]
}
}
return
}

0 comments on commit 4f0ea63

Please sign in to comment.