Skip to content

Commit

Permalink
use status code to check if there are config chagnes
Browse files Browse the repository at this point in the history
  • Loading branch information
garmr-ulfr committed Aug 14, 2024
1 parent 9e41c49 commit 35ce476
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 39 deletions.
3 changes: 2 additions & 1 deletion services/bypass.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ func (p *proxy) sendToBypass() (int64, error) {
return 0, err
}

io.Copy(io.Discard, resp)
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
return sleep, nil
}

Expand Down
18 changes: 7 additions & 11 deletions services/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ func (cs *configService) fetchConfig() (int64, error) {
cs.lastFetched = time.Now()

logger.Debug("configservice: Received config")
curConf := cs.configHandler.GetConfig()
if curConf != nil && !configIsNew(newConf) {
if newConf == nil {
op.Set("config_changed", false)
logger.Debug("configservice: Config is unchanged")
return sleep, nil
Expand Down Expand Up @@ -199,9 +198,13 @@ func (cs *configService) fetch() (*apipb.ConfigResponse, int64, error) {
if err != nil {
return nil, 0, fmt.Errorf("config request failed: %w", err)
}
defer resp.Close()

configBytes, err := io.ReadAll(resp)
if resp.StatusCode != http.StatusNoContent {
return nil, 0, nil // no config changes
}

configBytes, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, 0, fmt.Errorf("unable to read config response: %w", err)
}
Expand Down Expand Up @@ -238,10 +241,3 @@ func (cs *configService) newRequest() *apipb.ConfigRequest {

return confReq
}

// configIsNew returns true if any fields contain values.
func configIsNew(new *apipb.ConfigResponse) bool {
// We only need to check if the fields we're interested in contain values because the server
// will only send us new values if they have changed.
return new.Country != "" || new.ProToken != "" || len(new.Proxy.Proxies) > 0
}
56 changes: 29 additions & 27 deletions services/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,53 @@ import (
)

const (
// retryWaitMillis is the base wait time in milliseconds between retries
retryWaitMillis = 100
maxRetryWait = 10 * time.Minute
)

// sender is a helper for sending post requests. If the request fails, sender calulates an
// exponential backoff time and return it as the sleep time.
// exponential backoff time using retryWaitMillis and return it as the sleep time.
type sender struct {
failCount int
atMaxRetryWait bool
}

// post posts data to the specified URL and returns the response body, as a ReadCloser, the sleep
// time in seconds, and any error that occurred.
// post posts data to the specified URL and returns the response, the sleep time in seconds, and any
// error that occurred.
//
// Note: it is the responsibility of the caller to read the ReadCloser to completion and close it.
// Note: if the request is successful, it is the responsibility of the caller to read the response
// body to completion and close it.
func (s *sender) post(
originURL string,
buf io.Reader,
rt http.RoundTripper,
user common.UserConfig,
) (io.ReadCloser, int64, error) {
reader, sleepVal, err := s.post(originURL, buf, rt, user)
) (*http.Response, int64, error) {
resp, err := s.doPost(originURL, buf, rt, user)
if err == nil {
s.failCount = 0
s.atMaxRetryWait = false
return reader, sleepVal, nil

if resp.StatusCode != http.StatusOK || resp.StatusCode != http.StatusNoContent {
return nil, 0, fmt.Errorf("bad response code: %v", resp.StatusCode)
}

var sleepTime int64
if sleepVal := resp.Header.Get(common.SleepHeader); sleepVal != "" {
if sleepTime, err = strconv.ParseInt(sleepVal, 10, 64); err != nil {
logger.Errorf("Could not parse sleep val: %v", err)
}
}

return resp, sleepTime, nil
}

if s.atMaxRetryWait {
// we've already reached the max wait time, so we don't need to perform the calculation again.
// we'll still increment the fail count to keep track of the number of failures
s.failCount++
return reader, int64(maxRetryWait.Seconds()), err
return nil, int64(maxRetryWait.Seconds()), err
}

wait := time.Duration(math.Pow(2, float64(s.failCount)) * float64(retryWaitMillis))
Expand All @@ -53,21 +67,21 @@ func (s *sender) post(

if wait > maxRetryWait {
s.atMaxRetryWait = true
return reader, int64(maxRetryWait.Seconds()), err
return nil, int64(maxRetryWait.Seconds()), err
}

return reader, int64(wait.Seconds()), err
return nil, int64(wait.Seconds()), err
}

func (s *sender) doPost(
originURL string,
buf io.Reader,
rt http.RoundTripper,
user common.UserConfig,
) (io.ReadCloser, int64, error) {
) (*http.Response, error) {
req, err := http.NewRequest("POST", originURL, buf)
if err != nil {
return nil, 0, fmt.Errorf("unable to create request for %s: %w", originURL, err)
return nil, fmt.Errorf("unable to create request for %s: %w", originURL, err)
}

common.AddCommonHeaders(user, req)
Expand All @@ -81,22 +95,10 @@ func (s *sender) doPost(
req.Close = true
resp, err := rt.RoundTrip(req)
if err != nil {
return nil, 0, fmt.Errorf("request to %s failed: %w", originURL, err)
}

if resp.StatusCode != 200 {
return nil, 0, fmt.Errorf("bad response code: %v", resp.StatusCode)
resp.Body.Close()
return nil, fmt.Errorf("request to %s failed: %w", originURL, err)
}

logger.Debugf("Response headers from %v:\n%v", originURL, resp.Header)

var sleepTime int64
sleepVal := resp.Header.Get(common.SleepHeader)
if sleepVal != "" {
if sleepTime, err = strconv.ParseInt(sleepVal, 10, 64); err != nil {
logger.Errorf("Could not parse sleep val: %v", err)
}
}

return resp.Body, sleepTime, nil
return resp, nil
}

0 comments on commit 35ce476

Please sign in to comment.