From 661e55aedad401457757003500064cb7b7d19a1c Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 13 Sep 2023 19:36:41 +0200 Subject: [PATCH] New client API --- hedged.go | 71 +++++++++++++++++++++++++++++++++++++++++++------- hedged_test.go | 4 +-- 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/hedged.go b/hedged.go index 24e03f5..333403e 100644 --- a/hedged.go +++ b/hedged.go @@ -12,6 +12,55 @@ import ( const infiniteTimeout = 30 * 24 * time.Hour // domain specific infinite +type Client struct { + rt http.RoundTripper + stats *Stats +} + +type Config struct { + Transport http.RoundTripper + Upto int + Delay time.Duration + Foo FooFn +} + +type FooFn func() (upto int, delay time.Duration) + +func New(cfg Config) (*Client, error) { + switch { + case cfg.Delay < 0: + return nil, errors.New("hedgedhttp: timeout cannot be negative") + case cfg.Upto < 0: + return nil, errors.New("hedgedhttp: upto cannot be negative") + } + if cfg.Transport == nil { + cfg.Transport = http.DefaultTransport + } + + rt, stats, err := NewRoundTripperAndStats(cfg.Delay, cfg.Upto, cfg.Transport) + if err != nil { + return nil, err + } + + c := &Client{ + rt: rt, + stats: stats, + } + return c, nil +} + +func (c *Client) Stats() *Stats { + return c.stats +} + +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.rt.RoundTrip(req) +} + +func (c *Client) RoundTrip(req *http.Request) (*http.Response, error) { + return c.rt.RoundTrip(req) +} + // NewClient returns a new http.Client which implements hedged requests pattern. // Given Client starts a new request after a timeout from previous request. // Starts no more than upto requests. @@ -63,8 +112,8 @@ func NewRoundTripperAndStats(timeout time.Duration, upto int, rt http.RoundTripp switch { case timeout < 0: return nil, nil, errors.New("hedgedhttp: timeout cannot be negative") - case upto < 1: - return nil, nil, errors.New("hedgedhttp: upto must be greater than 0") + case upto < 0: + return nil, nil, errors.New("hedgedhttp: upto cannot be negative") } if rt == nil { @@ -85,6 +134,7 @@ func NewRoundTripperAndStats(timeout time.Duration, upto int, rt http.RoundTripp } type hedgedTransport struct { + foo FooFn rt http.RoundTripper timeout time.Duration upto int @@ -94,15 +144,18 @@ type hedgedTransport struct { func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) { mainCtx := req.Context() - timeout := ht.timeout + upto, timeout := ht.upto, ht.timeout + if ht.foo != nil { + upto, timeout = ht.foo() + } errOverall := &MultiError{} - resultCh := make(chan indexedResp, ht.upto) - errorCh := make(chan error, ht.upto) + resultCh := make(chan indexedResp, upto) + errorCh := make(chan error, upto) ht.metrics.requestedRoundTripsInc() resultIdx := -1 - cancels := make([]func(), ht.upto) + cancels := make([]func(), upto) defer runInPool(func() { for i, cancel := range cancels { @@ -113,8 +166,8 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) } }) - for sent := 0; len(errOverall.Errors) < ht.upto; sent++ { - if sent < ht.upto { + for sent := 0; len(errOverall.Errors) < upto; sent++ { + if sent < upto { idx := sent subReq, cancel := reqWithCtx(req, mainCtx, idx != 0) cancels[idx] = cancel @@ -132,7 +185,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) } // all request sent - effectively disabling timeout between requests - if sent == ht.upto { + if sent == upto { timeout = infiniteTimeout } resp, err := waitResult(mainCtx, resultCh, errorCh, timeout) diff --git a/hedged_test.go b/hedged_test.go index a931683..bbf0792 100644 --- a/hedged_test.go +++ b/hedged_test.go @@ -22,10 +22,10 @@ func TestValidateInput(t *testing.T) { _, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil) mustFail(t, err) - _, _, err = hedgedhttp.NewClientAndStats(time.Second, 0, nil) + _, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil) mustFail(t, err) - _, err = hedgedhttp.NewRoundTripper(time.Second, 0, nil) + _, err = hedgedhttp.NewRoundTripper(time.Second, -1, nil) mustFail(t, err) }