diff --git a/Makefile b/Makefile index 84577c7..1de773f 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ ldflags= \ -X "github.com/mike76-dev/hostscore/internal/build.NodeBinaryName=hsd" \ -X "github.com/mike76-dev/hostscore/internal/build.NodeVersion=1.1.2" \ -X "github.com/mike76-dev/hostscore/internal/build.ClientBinaryName=hsc" \ --X "github.com/mike76-dev/hostscore/internal/build.ClientVersion=1.4.0" \ +-X "github.com/mike76-dev/hostscore/internal/build.ClientVersion=1.4.1" \ -X "github.com/mike76-dev/hostscore/internal/build.GitRevision=${GIT_DIRTY}${GIT_REVISION}" \ -X "github.com/mike76-dev/hostscore/internal/build.BuildTime=${BUILD_TIME}" diff --git a/cmd/hsc/api.go b/cmd/hsc/api.go index 18f51b2..2ff5ded 100644 --- a/cmd/hsc/api.go +++ b/cmd/hsc/api.go @@ -191,6 +191,7 @@ type portalAPI struct { stopChan chan struct{} averages map[string]map[string]networkAverages nodes map[string]nodeStatus + rl *ratelimiter } func newAPI(s *jsonStore, db *sql.DB, token string, logger *zap.Logger, cache *responseCache) (*portalAPI, error) { @@ -210,6 +211,8 @@ func newAPI(s *jsonStore, db *sql.DB, token string, logger *zap.Logger, cache *r api.hosts["mainnet"] = make(map[types.PublicKey]*portalHost) api.hosts["zen"] = make(map[types.PublicKey]*portalHost) + api.rl = newRatelimiter(api.stopChan) + err := api.load() if err != nil { return nil, err @@ -370,6 +373,10 @@ func (api *portalAPI) buildHTTPRoutes() { } func (api *portalAPI) hostsHostHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -406,6 +413,10 @@ func (api *portalAPI) hostsHostHandler(w http.ResponseWriter, req *http.Request, } func (api *portalAPI) hostsHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -513,6 +524,10 @@ func (api *portalAPI) hostsHandler(w http.ResponseWriter, req *http.Request, _ h } func (api *portalAPI) hostsKeysHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } err := req.ParseForm() if err != nil { writeError(w, "unable to parse request", http.StatusBadRequest) @@ -678,6 +693,10 @@ func (api *portalAPI) hostsKeysHandler(w http.ResponseWriter, req *http.Request, } func (api *portalAPI) hostsScansHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -752,6 +771,10 @@ func (api *portalAPI) hostsScansHandler(w http.ResponseWriter, req *http.Request } func (api *portalAPI) hostsBenchmarksHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -835,7 +858,11 @@ func balanceStatus(balance types.Currency) string { return "ok" } -func (api *portalAPI) serviceStatusHandler(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { +func (api *portalAPI) serviceStatusHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } writeJSON(w, statusResponse{ Version: build.ClientVersion, Nodes: api.nodes, @@ -843,6 +870,10 @@ func (api *portalAPI) serviceStatusHandler(w http.ResponseWriter, _ *http.Reques } func (api *portalAPI) networkHostsHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -864,6 +895,10 @@ func (api *portalAPI) networkHostsHandler(w http.ResponseWriter, req *http.Reque } func (api *portalAPI) hostsChangesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -923,6 +958,10 @@ func (api *portalAPI) hostsChangesHandler(w http.ResponseWriter, req *http.Reque } func (api *portalAPI) networkAveragesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" @@ -935,6 +974,10 @@ func (api *portalAPI) networkAveragesHandler(w http.ResponseWriter, req *http.Re } func (api *portalAPI) networkCountriesHandler(w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + if api.rl.limitExceeded(getRemoteHost(req)) { + writeError(w, "too many requests", http.StatusTooManyRequests) + return + } network := strings.ToLower(req.FormValue("network")) if network == "" { network = "mainnet" diff --git a/cmd/hsc/ratelimit.go b/cmd/hsc/ratelimit.go new file mode 100644 index 0000000..67a5486 --- /dev/null +++ b/cmd/hsc/ratelimit.go @@ -0,0 +1,60 @@ +package main + +import ( + "net" + "net/http" + "sync" + "time" +) + +const maxRequestsPerSecond = 10 + +// ratelimiter keeps the API request stats and determines whether +// to allow the request or not. +type ratelimiter struct { + requests map[string]int + mu sync.Mutex +} + +func newRatelimiter(stopChan chan struct{}) *ratelimiter { + rl := &ratelimiter{ + requests: make(map[string]int), + } + + ticker := time.Tick(time.Second) + go func() { + for range ticker { + select { + case <-stopChan: + return + default: + } + rl.mu.Lock() + rl.requests = make(map[string]int) + rl.mu.Unlock() + } + }() + + return rl +} + +// limitExceeded returns true if there are too many requests from the given host. +func (rl *ratelimiter) limitExceeded(addr string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.requests[addr]++ + return rl.requests[addr] > maxRequestsPerSecond +} + +// getRemoteHost returns the address of the remote host. +func getRemoteHost(r *http.Request) (host string) { + host, _, _ = net.SplitHostPort(r.RemoteAddr) + if host == "127.0.0.1" || host == "localhost" { + xff := r.Header.Values("X-Forwarded-For") + if len(xff) > 0 { + host = xff[0] + } + } + return +}