Skip to content

Commit

Permalink
add http tests
Browse files Browse the repository at this point in the history
  • Loading branch information
garmr-ulfr committed Aug 14, 2024
1 parent 35ce476 commit 535cc4f
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 11 deletions.
22 changes: 11 additions & 11 deletions services/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (
)

const (
// retryWaitMillis is the base wait time in milliseconds between retries
retryWaitMillis = 100
maxRetryWait = 10 * time.Minute
// retryWaitSeconds is the base wait time in seconds between retries
retryWaitSeconds = 5 * time.Second
maxRetryWait = 10 * time.Minute
)

// sender is a helper for sending post requests. If the request fails, sender calulates an
// exponential backoff time using retryWaitMillis and return it as the sleep time.
// exponential backoff time using retryWaitSeconds and return it as the sleep time.
type sender struct {
failCount int
atMaxRetryWait bool
Expand All @@ -37,13 +37,14 @@ func (s *sender) post(
) (*http.Response, int64, error) {
resp, err := s.doPost(originURL, buf, rt, user)
if err == nil {
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
err = fmt.Errorf("bad response code: %v", resp.StatusCode)
goto backoff
}

s.failCount = 0
s.atMaxRetryWait = false

if resp.StatusCode != http.StatusOK || resp.StatusCode != http.StatusNoContent {
return nil, 0, fmt.Errorf("bad response code: %v", resp.StatusCode)
}

var sleepTime int64
if sleepVal := resp.Header.Get(common.SleepHeader); sleepVal != "" {
if sleepTime, err = strconv.ParseInt(sleepVal, 10, 64); err != nil {
Expand All @@ -54,15 +55,15 @@ func (s *sender) post(
return resp, sleepTime, nil
}

backoff:
if s.atMaxRetryWait {
// we've already reached the max wait time, so we don't need to perform the calculation again.
// we'll still increment the fail count to keep track of the number of failures
s.failCount++
return nil, int64(maxRetryWait.Seconds()), err
}

wait := time.Duration(math.Pow(2, float64(s.failCount)) * float64(retryWaitMillis))
wait *= time.Millisecond
wait := time.Duration(math.Pow(2, float64(s.failCount))) * retryWaitSeconds
s.failCount++

if wait > maxRetryWait {
Expand Down Expand Up @@ -95,7 +96,6 @@ func (s *sender) doPost(
req.Close = true
resp, err := rt.RoundTrip(req)
if err != nil {
resp.Body.Close()
return nil, fmt.Errorf("request to %s failed: %w", originURL, err)
}

Expand Down
77 changes: 77 additions & 0 deletions services/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package services

import (
"math"
mrand "math/rand/v2"
"net/http"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/getlantern/flashlight/v7/common"
)

func TestPost(t *testing.T) {
sdr := &sender{}
rt := &mockRoundTripper{
status: http.StatusOK,
sleep: mrand.IntN(10),
}
user := common.NullUserConfig{}
_, sleep, err := sdr.post("http://example.com", nil, rt, user)
require.NoError(t, err)

assert.Equal(t, rt.sleep, int(sleep), "response sleep value does not match")

testBackoff := func(t *testing.T, rt *mockRoundTripper) {
sdr := &sender{}
for i := 0; i < 5; i++ {
wait := time.Duration(math.Pow(2, float64(i))) * retryWaitSeconds
want := int64(wait.Seconds())
_, sleep, err = sdr.post("http://example.com", nil, rt, user)
assert.Equal(t, want, sleep, "returned sleep value does not follow an exponential backoff")
}
}

t.Run("backoff on error", func(t *testing.T) {
rt = &mockRoundTripper{shouldErr: true}
testBackoff(t, rt)
})

t.Run("backoff on bad StatusCode", func(t *testing.T) {
rt = &mockRoundTripper{status: http.StatusBadRequest}
testBackoff(t, rt)
})
}

func TestDoPost(t *testing.T) {
sdr := &sender{}
rt := &mockRoundTripper{status: http.StatusOK}
_, err := sdr.doPost("http://example.com", nil, rt, common.NullUserConfig{})
assert.NoError(t, err)
assert.True(t, rt.req.Close, "request.Close should be set to true before calling RoundTrip")
}

type mockRoundTripper struct {
req *http.Request
status int
sleep int
shouldErr bool
}

func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.req = req
if m.shouldErr {
return nil, assert.AnError
}

header := http.Header{}
header.Add(common.SleepHeader, strconv.Itoa(m.sleep))
return &http.Response{
StatusCode: m.status,
Header: header,
}, nil
}

0 comments on commit 535cc4f

Please sign in to comment.