diff --git a/pkg/experiment/remote/client.go b/pkg/experiment/remote/client.go index 456a70b..83e8eb6 100644 --- a/pkg/experiment/remote/client.go +++ b/pkg/experiment/remote/client.go @@ -42,6 +42,7 @@ func Initialize(apiKey string, config *Config) *Client { client: &http.Client{}, } client.log.Debug("config: %v", *config) + clients[apiKey] = client } initMutex.Unlock() return client @@ -60,11 +61,17 @@ func (c *Client) Fetch(user *experiment.User) (map[string]experiment.Variant, er // FetchV2 fetches variants for a user from the remote evaluation service. // Unlike Fetch, this method returns all variants, including default variants. func (c *Client) FetchV2(user *experiment.User) (map[string]experiment.Variant, error) { - variants, err := c.doFetch(user, c.config.FetchTimeout) + ctx := context.Background() + return c.FetchV2WithContext(user, ctx) +} + +// FetchV2WithContext fetches variants for a user from the remote evaluation service with a context. +func (c *Client) FetchV2WithContext(user *experiment.User, ctx context.Context) (map[string]experiment.Variant, error) { + variants, err := c.doFetch(ctx, user, c.config.FetchTimeout) if err != nil { c.log.Error("fetch error: %v", err) if c.config.RetryBackoff.FetchRetries > 0 && shouldRetryFetch(err) { - return c.retryFetch(user) + return c.retryFetch(ctx, user) } else { return nil, err } @@ -72,7 +79,7 @@ func (c *Client) FetchV2(user *experiment.User) (map[string]experiment.Variant, return variants, err } -func (c *Client) doFetch(user *experiment.User, timeout time.Duration) (map[string]experiment.Variant, error) { +func (c *Client) doFetch(ctx context.Context, user *experiment.User, timeout time.Duration) (map[string]experiment.Variant, error) { addLibraryContext(user) endpoint, err := url.Parse(c.config.ServerUrl) if err != nil { @@ -87,7 +94,7 @@ func (c *Client) doFetch(user *experiment.User, timeout time.Duration) (map[stri return nil, err } c.log.Debug("fetch variants for user %s", string(jsonBytes)) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) if err != nil { @@ -109,7 +116,7 @@ func (c *Client) doFetch(user *experiment.User, timeout time.Duration) (map[stri return c.parseResponse(resp) } -func (c *Client) retryFetch(user *experiment.User) (map[string]experiment.Variant, error) { +func (c *Client) retryFetch(ctx context.Context, user *experiment.User) (map[string]experiment.Variant, error) { var err error var variants map[string]experiment.Variant var timer *time.Timer @@ -118,7 +125,7 @@ func (c *Client) retryFetch(user *experiment.User) (map[string]experiment.Varian c.log.Debug("retry attempt %v", i) timer = time.NewTimer(delay) <-timer.C - variants, err = c.doFetch(user, c.config.RetryBackoff.FetchRetryTimeout) + variants, err = c.doFetch(ctx, user, c.config.RetryBackoff.FetchRetryTimeout) if err == nil && variants != nil { c.log.Debug("retry attempt %v success", i) return variants, nil diff --git a/pkg/experiment/remote/client_test.go b/pkg/experiment/remote/client_test.go index 59d9d45..0198988 100644 --- a/pkg/experiment/remote/client_test.go +++ b/pkg/experiment/remote/client_test.go @@ -22,7 +22,6 @@ func TestClient_Fetch_DoesNotReturnDefaultVariants(t *testing.T) { require.Empty(t, variant) } - func TestClient_FetchV2_ReturnsDefaultVariants(t *testing.T) { client := Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", nil) user := &experiment.User{} diff --git a/pkg/experiment/remote/config.go b/pkg/experiment/remote/config.go index aedd236..4f9349c 100644 --- a/pkg/experiment/remote/config.go +++ b/pkg/experiment/remote/config.go @@ -1,6 +1,8 @@ package remote -import "time" +import ( + "time" +) type Config struct { Debug bool