diff --git a/go.mod b/go.mod
index bad5092..bb7e6b0 100644
--- a/go.mod
+++ b/go.mod
@@ -8,5 +8,7 @@ require (
 	github.com/amplitude/analytics-go v1.0.1
 	github.com/jarcoal/httpmock v1.3.1
 	github.com/joho/godotenv v1.5.1
+	github.com/r3labs/sse/v2 v2.10.0
 	github.com/stretchr/testify v1.9.0
+	gopkg.in/cenkalti/backoff.v1 v1.1.0
 )
diff --git a/go.sum b/go.sum
index ec3c7e2..ba2228c 100644
--- a/go.sum
+++ b/go.sum
@@ -13,6 +13,8 @@ github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04
 github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
+github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
 github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
 github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -20,12 +22,20 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
 github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/net v0.0.0-20191116160921-f9c825593386 h1:ktbWvQrW08Txdxno1PiDpSxPXG6ndGsfnJjRRtkM0LQ=
+golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
+gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/pkg/experiment/local/assignment_service_test.go b/pkg/experiment/local/assignment_service_test.go
index d1d5b77..4d364d6 100644
--- a/pkg/experiment/local/assignment_service_test.go
+++ b/pkg/experiment/local/assignment_service_test.go
@@ -22,9 +22,9 @@ func TestToEvent(t *testing.T) {
 			},
 		},
 		"flag-key-2": {
-			Key: "control",
+			Key:      "control",
 			Metadata: map[string]interface{}{
-				"default": true,
+				"default":     true,
 				"segmentName": "All Other Users",
 				"flagVersion": float64(12),
 			},
diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go
index 07286a7..7d3db54 100644
--- a/pkg/experiment/local/client.go
+++ b/pkg/experiment/local/client.go
@@ -4,13 +4,14 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/amplitude/analytics-go/amplitude"
 	"io/ioutil"
 	"net/http"
 	"net/url"
 	"reflect"
 	"sync"
 
+	"github.com/amplitude/analytics-go/amplitude"
+
 	"github.com/amplitude/experiment-go-server/internal/evaluation"
 
 	"github.com/amplitude/experiment-go-server/pkg/experiment"
@@ -59,9 +60,16 @@ func Initialize(apiKey string, config *Config) *Client {
 	var deploymentRunner *deploymentRunner
 	if config.CohortSyncConfig != nil {
 		cohortDownloadApi := newDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortServerUrl, config.Debug)
-		cohortLoader = newCohortLoader(cohortDownloadApi, cohortStorage)
+		cohortLoader = newCohortLoader(cohortDownloadApi, cohortStorage, config.Debug)
+	}
+	var flagStreamApi *flagConfigStreamApiV2
+	if config.StreamUpdates {
+		flagStreamApi = newFlagConfigStreamApiV2(apiKey, config.StreamServerUrl, config.StreamFlagConnTimeout)
 	}
-	deploymentRunner = newDeploymentRunner(config, newFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader)
+	deploymentRunner = newDeploymentRunner(
+		config,
+		newFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout),
+		flagStreamApi, flagConfigStorage, cohortStorage, cohortLoader)
 	client = &Client{
 		log:    log,
 		apiKey: apiKey,
diff --git a/pkg/experiment/local/client_stream_test.go b/pkg/experiment/local/client_stream_test.go
new file mode 100644
index 0000000..543d7d4
--- /dev/null
+++ b/pkg/experiment/local/client_stream_test.go
@@ -0,0 +1,157 @@
+package local
+
+import (
+	"log"
+	"os"
+	"testing"
+
+	"github.com/amplitude/experiment-go-server/pkg/experiment"
+	"github.com/joho/godotenv"
+	"github.com/stretchr/testify/assert"
+)
+
+var streamClient *Client
+
+func init() {
+	err := godotenv.Load()
+	if err != nil {
+		log.Printf("Error loading .env file: %v", err)
+	}
+	projectApiKey := os.Getenv("API_KEY")
+	secretKey := os.Getenv("SECRET_KEY")
+	cohortSyncConfig := CohortSyncConfig{
+		ApiKey:    projectApiKey,
+		SecretKey: secretKey,
+	}
+	streamClient = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz",
+		&Config{
+			StreamUpdates:    true,
+			StreamServerUrl:  "https://stream.lab.amplitude.com",
+			CohortSyncConfig: &cohortSyncConfig,
+		})
+	err = streamClient.Start()
+	if err != nil {
+		panic(err)
+	}
+}
+
+func TestMakeSureStreamEnabled(t *testing.T) {
+	assert.True(t, streamClient.config.StreamUpdates)
+}
+
+func TestStreamEvaluate(t *testing.T) {
+	user := &experiment.User{UserId: "test_user"}
+	result, err := streamClient.Evaluate(user, nil)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant := result["sdk-local-evaluation-ci-test"]
+	if variant.Key != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Payload != "payload" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	variant = result["sdk-ci-test"]
+	if variant.Key != "" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+}
+
+func TestStreamEvaluateV2AllFlags(t *testing.T) {
+	user := &experiment.User{UserId: "test_user"}
+	result, err := streamClient.EvaluateV2(user, nil)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant := result["sdk-local-evaluation-ci-test"]
+	if variant.Key != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Payload != "payload" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	variant = result["sdk-ci-test"]
+	if variant.Key != "off" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+}
+
+func TestStreamFlagMetadataLocalFlagKey(t *testing.T) {
+	md := streamClient.FlagMetadata("sdk-local-evaluation-ci-test")
+	if md["evaluationMode"] != "local" {
+		t.Fatalf("Unexpected metadata %v", md)
+	}
+}
+
+func TestStreamEvaluateV2Cohort(t *testing.T) {
+	targetedUser := &experiment.User{UserId: "12345"}
+	nonTargetedUser := &experiment.User{UserId: "not_targeted"}
+	flagKeys := []string{"sdk-local-evaluation-user-cohort-ci-test"}
+	result, err := streamClient.EvaluateV2(targetedUser, flagKeys)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant := result["sdk-local-evaluation-user-cohort-ci-test"]
+	if variant.Key != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	result, err = streamClient.EvaluateV2(nonTargetedUser, flagKeys)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant = result["sdk-local-evaluation-user-cohort-ci-test"]
+	if variant.Key != "off" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+}
+
+func TestStreamEvaluateV2GroupCohort(t *testing.T) {
+	targetedUser := &experiment.User{
+		UserId:   "12345",
+		DeviceId: "device_id",
+		Groups: map[string][]string{
+			"org id": {"1"},
+		}}
+	nonTargetedUser := &experiment.User{
+		UserId:   "12345",
+		DeviceId: "device_id",
+		Groups: map[string][]string{
+			"org id": {"not_targeted"},
+		}}
+	flagKeys := []string{"sdk-local-evaluation-group-cohort-ci-test"}
+	result, err := streamClient.EvaluateV2(targetedUser, flagKeys)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant := result["sdk-local-evaluation-group-cohort-ci-test"]
+	if variant.Key != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	if variant.Value != "on" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+	result, err = streamClient.EvaluateV2(nonTargetedUser, flagKeys)
+	if err != nil {
+		t.Fatalf("Unexpected error %v", err)
+	}
+	variant = result["sdk-local-evaluation-group-cohort-ci-test"]
+	if variant.Key != "off" {
+		t.Fatalf("Unexpected variant %v", variant)
+	}
+}
diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go
index bada72a..8cb1b68 100644
--- a/pkg/experiment/local/client_test.go
+++ b/pkg/experiment/local/client_test.go
@@ -1,11 +1,12 @@
 package local
 
 import (
-	"github.com/amplitude/experiment-go-server/pkg/experiment"
-	"github.com/joho/godotenv"
 	"log"
 	"os"
 	"testing"
+
+	"github.com/amplitude/experiment-go-server/pkg/experiment"
+	"github.com/joho/godotenv"
 )
 
 var client *Client
diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go
index d325315..e3d234d 100644
--- a/pkg/experiment/local/cohort_loader.go
+++ b/pkg/experiment/local/cohort_loader.go
@@ -1,11 +1,16 @@
 package local
 
 import (
+	"fmt"
+	"strings"
 	"sync"
 	"sync/atomic"
+
+	"github.com/amplitude/experiment-go-server/internal/logger"
 )
 
 type cohortLoader struct {
+	log               *logger.Log
 	cohortDownloadApi cohortDownloadApi
 	cohortStorage     cohortStorage
 	jobs              sync.Map
@@ -13,7 +18,7 @@ type cohortLoader struct {
 	lockJobs          sync.Mutex
 }
 
-func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage cohortStorage) *cohortLoader {
+func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage cohortStorage, debug bool) *cohortLoader {
 	return &cohortLoader{
 		cohortDownloadApi: cohortDownloadApi,
 		cohortStorage:     cohortStorage,
@@ -22,6 +27,7 @@ func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage cohortSt
 				return &CohortLoaderTask{}
 			},
 		},
+		log: logger.New(debug),
 	}
 }
 
@@ -86,3 +92,34 @@ func (cl *cohortLoader) downloadCohort(cohortID string) (*Cohort, error) {
 	cohort := cl.cohortStorage.getCohort(cohortID)
 	return cl.cohortDownloadApi.getCohort(cohortID, cohort)
 }
+
+func (cl *cohortLoader) downloadCohorts(cohortIDs map[string]struct{}) {
+	var wg sync.WaitGroup
+	errorChan := make(chan error, len(cohortIDs))
+
+	for cohortID := range cohortIDs {
+		wg.Add(1)
+		go func(id string) {
+			defer wg.Done()
+			task := cl.loadCohort(id)
+			if err := task.wait(); err != nil {
+				errorChan <- fmt.Errorf("cohort %s: %v", id, err)
+			}
+		}(cohortID)
+	}
+
+	go func() {
+		wg.Wait()
+		close(errorChan)
+	}()
+
+	var errorMessages []string
+	for err := range errorChan {
+		errorMessages = append(errorMessages, err.Error())
+		cl.log.Error("Error downloading cohort: %v", err)
+	}
+
+	if len(errorMessages) > 0 {
+		cl.log.Error("One or more cohorts failed to download:\n%s", strings.Join(errorMessages, "\n"))
+	}
+}
diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go
index 4f483e1..2dad534 100644
--- a/pkg/experiment/local/cohort_loader_test.go
+++ b/pkg/experiment/local/cohort_loader_test.go
@@ -10,7 +10,7 @@ import (
 func TestLoadSuccess(t *testing.T) {
 	api := &MockCohortDownloadApi{}
 	storage := newInMemoryCohortStorage()
-	loader := newCohortLoader(api, storage)
+	loader := newCohortLoader(api, storage, true)
 
 	// Define mock behavior
 	api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil)
@@ -48,7 +48,7 @@ func TestLoadSuccess(t *testing.T) {
 func TestFilterCohortsAlreadyComputed(t *testing.T) {
 	api := &MockCohortDownloadApi{}
 	storage := newInMemoryCohortStorage()
-	loader := newCohortLoader(api, storage)
+	loader := newCohortLoader(api, storage, true)
 
 	storage.putCohort(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}})
 	storage.putCohort(&Cohort{Id: "b", LastModified: 0, Size: 0, MemberIds: []string{}})
@@ -89,7 +89,7 @@ func TestFilterCohortsAlreadyComputed(t *testing.T) {
 func TestLoadDownloadFailureThrows(t *testing.T) {
 	api := &MockCohortDownloadApi{}
 	storage := newInMemoryCohortStorage()
-	loader := newCohortLoader(api, storage)
+	loader := newCohortLoader(api, storage, true)
 
 	// Define mock behavior
 	api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil)
diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go
index c896fd7..e84ba70 100644
--- a/pkg/experiment/local/config.go
+++ b/pkg/experiment/local/config.go
@@ -1,12 +1,14 @@
 package local
 
 import (
-	"github.com/amplitude/analytics-go/amplitude"
 	"math"
 	"time"
+
+	"github.com/amplitude/analytics-go/amplitude"
 )
 
 const EUFlagServerUrl = "https://flag.lab.eu.amplitude.com"
+const EUFlagStreamServerUrl = "https://stream.lab.eu.amplitude.com"
 const EUCohortSyncUrl = "https://cohort-v2.lab.eu.amplitude.com"
 
 type ServerZone int
@@ -22,6 +24,9 @@ type Config struct {
 	ServerZone                     ServerZone
 	FlagConfigPollerInterval       time.Duration
 	FlagConfigPollerRequestTimeout time.Duration
+	StreamUpdates                  bool
+	StreamServerUrl                string
+	StreamFlagConnTimeout          time.Duration
 	AssignmentConfig               *AssignmentConfig
 	CohortSyncConfig               *CohortSyncConfig
 }
@@ -45,6 +50,9 @@ var DefaultConfig = &Config{
 	ServerZone:                     USServerZone,
 	FlagConfigPollerInterval:       30 * time.Second,
 	FlagConfigPollerRequestTimeout: 10 * time.Second,
+	StreamUpdates:                  false,
+	StreamServerUrl:                "https://stream.lab.amplitude.com",
+	StreamFlagConnTimeout:          1500 * time.Millisecond,
 }
 
 var DefaultAssignmentConfig = &AssignmentConfig{
@@ -68,8 +76,10 @@ func fillConfigDefaults(c *Config) *Config {
 		switch c.ServerZone {
 		case USServerZone:
 			c.ServerUrl = DefaultConfig.ServerUrl
+			c.StreamServerUrl = DefaultConfig.StreamServerUrl
 		case EUServerZone:
 			c.ServerUrl = EUFlagServerUrl
+			c.StreamServerUrl = EUFlagStreamServerUrl
 		}
 	}
 
@@ -79,6 +89,9 @@ func fillConfigDefaults(c *Config) *Config {
 	if c.FlagConfigPollerRequestTimeout == 0 {
 		c.FlagConfigPollerRequestTimeout = DefaultConfig.FlagConfigPollerRequestTimeout
 	}
+	if c.StreamFlagConnTimeout == 0 {
+		c.StreamFlagConnTimeout = DefaultConfig.StreamFlagConnTimeout
+	}
 	if c.AssignmentConfig != nil && c.AssignmentConfig.CacheCapacity == 0 {
 		c.AssignmentConfig.CacheCapacity = DefaultAssignmentConfig.CacheCapacity
 	}
diff --git a/pkg/experiment/local/config_test.go b/pkg/experiment/local/config_test.go
index 6c790e7..b9ffd73 100644
--- a/pkg/experiment/local/config_test.go
+++ b/pkg/experiment/local/config_test.go
@@ -11,42 +11,49 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) {
 		input        *Config
 		expectedZone ServerZone
 		expectedUrl  string
+		expectedStreamUrl string
 	}{
 		{
 			name:         "Nil config",
 			input:        nil,
 			expectedZone: DefaultConfig.ServerZone,
 			expectedUrl:  DefaultConfig.ServerUrl,
+			expectedStreamUrl: DefaultConfig.StreamServerUrl,
 		},
 		{
 			name:         "Empty ServerZone",
 			input:        &Config{},
 			expectedZone: DefaultConfig.ServerZone,
 			expectedUrl:  DefaultConfig.ServerUrl,
+			expectedStreamUrl: DefaultConfig.StreamServerUrl,
 		},
 		{
 			name:         "ServerZone US",
 			input:        &Config{ServerZone: USServerZone},
 			expectedZone: USServerZone,
 			expectedUrl:  DefaultConfig.ServerUrl,
+			expectedStreamUrl: DefaultConfig.StreamServerUrl,
 		},
 		{
 			name:         "ServerZone EU",
 			input:        &Config{ServerZone: EUServerZone},
 			expectedZone: EUServerZone,
 			expectedUrl:  EUFlagServerUrl,
+			expectedStreamUrl: EUFlagStreamServerUrl,
 		},
 		{
 			name:         "Uppercase ServerZone EU",
 			input:        &Config{ServerZone: EUServerZone},
 			expectedZone: EUServerZone,
 			expectedUrl:  EUFlagServerUrl,
+			expectedStreamUrl: EUFlagStreamServerUrl,
 		},
 		{
 			name:         "Custom ServerUrl",
-			input:        &Config{ServerZone: USServerZone, ServerUrl: "https://custom.url/"},
+			input:        &Config{ServerZone: USServerZone, ServerUrl: "https://custom.url/", StreamServerUrl: "https://stream.custom.url"},
 			expectedZone: USServerZone,
 			expectedUrl:  "https://custom.url/",
+			expectedStreamUrl: "https://stream.custom.url",
 		},
 	}
 
@@ -59,6 +66,9 @@ func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) {
 			if result.ServerUrl != tt.expectedUrl {
 				t.Errorf("expected ServerUrl %s, got %s", tt.expectedUrl, result.ServerUrl)
 			}
+			if result.StreamServerUrl != tt.expectedStreamUrl {
+				t.Errorf("expected StreamServerUrl %s, got %s", tt.expectedStreamUrl, result.StreamServerUrl)
+			}
 		})
 	}
 }
@@ -133,6 +143,7 @@ func TestFillConfigDefaults_DefaultValues(t *testing.T) {
 			expected: &Config{
 				ServerZone:                     DefaultConfig.ServerZone,
 				ServerUrl:                      DefaultConfig.ServerUrl,
+				StreamServerUrl:                DefaultConfig.StreamServerUrl,
 				FlagConfigPollerInterval:       DefaultConfig.FlagConfigPollerInterval,
 				FlagConfigPollerRequestTimeout: DefaultConfig.FlagConfigPollerRequestTimeout,
 			},
@@ -142,12 +153,14 @@ func TestFillConfigDefaults_DefaultValues(t *testing.T) {
 			input: &Config{
 				ServerZone:                     EUServerZone,
 				ServerUrl:                      "https://custom.url/",
+				StreamServerUrl:                "https://stream.custom.url",
 				FlagConfigPollerInterval:       60 * time.Second,
 				FlagConfigPollerRequestTimeout: 20 * time.Second,
 			},
 			expected: &Config{
 				ServerZone:                     EUServerZone,
 				ServerUrl:                      "https://custom.url/",
+				StreamServerUrl:                "https://stream.custom.url",
 				FlagConfigPollerInterval:       60 * time.Second,
 				FlagConfigPollerRequestTimeout: 20 * time.Second,
 			},
@@ -163,6 +176,9 @@ func TestFillConfigDefaults_DefaultValues(t *testing.T) {
 			if result.ServerUrl != tt.expected.ServerUrl {
 				t.Errorf("expected ServerUrl %s, got %s", tt.expected.ServerUrl, result.ServerUrl)
 			}
+			if result.StreamServerUrl != tt.expected.StreamServerUrl {
+				t.Errorf("expected StreamServerUrl %s, got %s", tt.expected.StreamServerUrl, result.StreamServerUrl)
+			}
 			if result.FlagConfigPollerInterval != tt.expected.FlagConfigPollerInterval {
 				t.Errorf("expected FlagConfigPollerInterval %v, got %v", tt.expected.FlagConfigPollerInterval, result.FlagConfigPollerInterval)
 			}
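[Reviewer note] Taken together, the client and config changes above make streaming strictly opt-in: StreamUpdates defaults to false, StreamServerUrl follows the ServerZone (US default, or EUFlagStreamServerUrl for EU) unless set explicitly, and StreamFlagConnTimeout defaults to 1500ms. A minimal consumer sketch, assuming only what this diff adds (the deployment key is a placeholder):

	package main

	import (
		"fmt"

		"github.com/amplitude/experiment-go-server/pkg/experiment"
		"github.com/amplitude/experiment-go-server/pkg/experiment/local"
	)

	func main() {
		// StreamUpdates switches flag config delivery to SSE; polling
		// remains the automatic fallback (see deployment_runner.go below).
		client := local.Initialize("server-deployment-key", &local.Config{
			StreamUpdates: true,
		})
		if err := client.Start(); err != nil {
			panic(err)
		}
		variants, err := client.EvaluateV2(&experiment.User{UserId: "test_user"}, nil)
		if err != nil {
			panic(err)
		}
		fmt.Println(variants)
	}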
diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go
index 2b85437..bf22d82 100644
--- a/pkg/experiment/local/deployment_runner.go
+++ b/pkg/experiment/local/deployment_runner.go
@@ -1,196 +1,61 @@
 package local
 
 import (
-	"fmt"
-	"github.com/amplitude/experiment-go-server/internal/evaluation"
-	"github.com/amplitude/experiment-go-server/internal/logger"
-	"strings"
 	"sync"
+	"time"
+
+	"github.com/amplitude/experiment-go-server/internal/logger"
 )
 
 type deploymentRunner struct {
 	config            *Config
-	flagConfigApi     flagConfigApi
 	flagConfigStorage flagConfigStorage
-	cohortStorage     cohortStorage
+	flagConfigUpdater flagConfigUpdater
 	cohortLoader      *cohortLoader
-	lock              sync.Mutex
 	poller            *poller
+	lock              sync.Mutex
 	log               *logger.Log
 }
 
+const streamUpdaterRetryDelay = 15 * time.Second
+const updaterRetryMaxJitter = 1 * time.Second
+
 func newDeploymentRunner(
 	config *Config,
 	flagConfigApi flagConfigApi,
+	flagConfigStreamApi *flagConfigStreamApiV2,
 	flagConfigStorage flagConfigStorage,
 	cohortStorage cohortStorage,
 	cohortLoader *cohortLoader,
 ) *deploymentRunner {
+	flagConfigUpdater := newflagConfigFallbackRetryWrapper(newFlagConfigPoller(flagConfigApi, config, flagConfigStorage, cohortStorage, cohortLoader), nil, config.FlagConfigPollerInterval, updaterRetryMaxJitter, 0, 0, config.Debug)
+	if flagConfigStreamApi != nil {
+		flagConfigUpdater = newflagConfigFallbackRetryWrapper(newFlagConfigStreamer(flagConfigStreamApi, config, flagConfigStorage, cohortStorage, cohortLoader), flagConfigUpdater, streamUpdaterRetryDelay, updaterRetryMaxJitter, config.FlagConfigPollerInterval, 0, config.Debug)
+	}
 	dr := &deploymentRunner{
 		config:            config,
-		flagConfigApi:     flagConfigApi,
 		flagConfigStorage: flagConfigStorage,
-		cohortStorage:     cohortStorage,
 		cohortLoader:      cohortLoader,
+		flagConfigUpdater: flagConfigUpdater,
+		poller:            newPoller(),
 		log:               logger.New(config.Debug),
 	}
-	dr.poller = newPoller()
 	return dr
 }
 
 func (dr *deploymentRunner) start() error {
 	dr.lock.Lock()
 	defer dr.lock.Unlock()
-
-	if err := dr.updateFlagConfigs(); err != nil {
-		dr.log.Error("Initial updateFlagConfigs failed: %v", err)
+	err := dr.flagConfigUpdater.Start(nil)
+	if err != nil {
 		return err
 	}
 
-	dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() {
-		if err := dr.periodicRefresh(); err != nil {
-			dr.log.Error("Periodic updateFlagConfigs failed: %v", err)
-		}
-	})
-
 	if dr.config.CohortSyncConfig != nil {
 		dr.poller.Poll(dr.config.CohortSyncConfig.CohortPollingInterval, func() {
-			dr.updateStoredCohorts()
+			cohortIDs := getAllCohortIDsFromFlags(dr.flagConfigStorage.getFlagConfigsArray())
+			dr.cohortLoader.downloadCohorts(cohortIDs)
 		})
 	}
 	return nil
 }
-
-func (dr *deploymentRunner) periodicRefresh() error {
-	defer func() {
-		if r := recover(); r != nil {
-			dr.log.Error("Recovered in periodicRefresh: %v", r)
-		}
-	}()
-	return dr.updateFlagConfigs()
-}
-
-func (dr *deploymentRunner) updateFlagConfigs() error {
-	dr.log.Debug("Refreshing flag configs.")
-	flagConfigs, err := dr.flagConfigApi.getFlagConfigs()
-	if err != nil {
-		dr.log.Error("Failed to fetch flag configs: %v", err)
-		return err
-	}
-
-	flagKeys := make(map[string]struct{})
-	for _, flag := range flagConfigs {
-		flagKeys[flag.Key] = struct{}{}
-	}
-
-	dr.flagConfigStorage.removeIf(func(f *evaluation.Flag) bool {
-		_, exists := flagKeys[f.Key]
-		return !exists
-	})
-
-	if dr.cohortLoader == nil {
-		for _, flagConfig := range flagConfigs {
-			dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key)
-			dr.flagConfigStorage.putFlagConfig(flagConfig)
-		}
-		return nil
-	}
-
-	newCohortIDs := make(map[string]struct{})
-	for _, flagConfig := range flagConfigs {
-		for cohortID := range getAllCohortIDsFromFlag(flagConfig) {
-			newCohortIDs[cohortID] = struct{}{}
-		}
-	}
-
-	existingCohortIDs := dr.cohortStorage.getCohortIds()
-	cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs)
-
-	// Download all new cohorts
-	dr.downloadCohorts(cohortIDsToDownload)
-
-	// Get updated set of cohort ids
-	updatedCohortIDs := dr.cohortStorage.getCohortIds()
-	// Iterate through new flag configs and check if their required cohorts exist
-	for _, flagConfig := range flagConfigs {
-		cohortIDs := getAllCohortIDsFromFlag(flagConfig)
-		missingCohorts := difference(cohortIDs, updatedCohortIDs)
-
-		dr.flagConfigStorage.putFlagConfig(flagConfig)
-		dr.log.Debug("Putting flag %s", flagConfig.Key)
-		if len(missingCohorts) != 0 {
-			dr.log.Error("Flag %s - failed to load cohorts: %v", flagConfig.Key, missingCohorts)
-		}
-	}
-
-	// Delete unused cohorts
-	dr.deleteUnusedCohorts()
-	dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs))
-
-	return nil
-}
-
-func (dr *deploymentRunner) updateStoredCohorts() {
-	cohortIDs := getAllCohortIDsFromFlags(dr.flagConfigStorage.getFlagConfigsArray())
-	dr.downloadCohorts(cohortIDs)
-}
-
-func (dr *deploymentRunner) deleteUnusedCohorts() {
-	flagCohortIDs := make(map[string]struct{})
-	for _, flag := range dr.flagConfigStorage.getFlagConfigs() {
-		for cohortID := range getAllCohortIDsFromFlag(flag) {
-			flagCohortIDs[cohortID] = struct{}{}
-		}
-	}
-
-	storageCohorts := dr.cohortStorage.getCohorts()
-	for cohortID := range storageCohorts {
-		if _, exists := flagCohortIDs[cohortID]; !exists {
-			cohort := storageCohorts[cohortID]
-			if cohort != nil {
-				dr.cohortStorage.deleteCohort(cohort.GroupType, cohortID)
-			}
-		}
-	}
-}
-
-func difference(set1, set2 map[string]struct{}) map[string]struct{} {
-	diff := make(map[string]struct{})
-	for k := range set1 {
-		if _, exists := set2[k]; !exists {
-			diff[k] = struct{}{}
-		}
-	}
-	return diff
-}
-
-func (dr *deploymentRunner) downloadCohorts(cohortIDs map[string]struct{}) {
-	var wg sync.WaitGroup
-	errorChan := make(chan error, len(cohortIDs))
-
-	for cohortID := range cohortIDs {
-		wg.Add(1)
-		go func(id string) {
-			defer wg.Done()
-			task := dr.cohortLoader.loadCohort(id)
-			if err := task.wait(); err != nil {
-				errorChan <- fmt.Errorf("cohort %s: %v", id, err)
-			}
-		}(cohortID)
-	}
-
-	go func() {
-		wg.Wait()
-		close(errorChan)
-	}()
-
-	var errorMessages []string
-	for err := range errorChan {
-		errorMessages = append(errorMessages, err.Error())
-		dr.log.Error("Error downloading cohort: %v", err)
-	}
-
-	if len(errorMessages) > 0 {
-		dr.log.Error("One or more cohorts failed to download:\n%s", strings.Join(errorMessages, "\n"))
-	}
-}
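[Reviewer note] The cohort polling loop above now delegates to cohortLoader.downloadCohorts (moved from this file into cohort_loader.go). It fans out one goroutine per cohort ID, waits for all of them, and aggregates failures into log output instead of aborting on the first error. A hedged usage sketch (the cohort IDs are illustrative):

	// Blocks until every requested download settles; failures are logged
	// by the loader rather than returned to the caller.
	loader.downloadCohorts(map[string]struct{}{
		"cohort_a": {},
		"cohort_b": {},
	})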
diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go
index 691dae7..a5f1e57 100644
--- a/pkg/experiment/local/deployment_runner_test.go
+++ b/pkg/experiment/local/deployment_runner_test.go
@@ -19,11 +19,12 @@ func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) {
 	cohortDownloadAPI := &mockCohortDownloadApi{}
 	flagConfigStorage := newInMemoryFlagConfigStorage()
 	cohortStorage := newInMemoryCohortStorage()
-	cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage)
+	cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage, true)
 
 	runner := newDeploymentRunner(
 		&Config{},
 		flagAPI,
+		nil,
 		flagConfigStorage,
 		cohortStorage,
 		cohortLoader,
@@ -45,11 +46,12 @@ func TestStartSucceedsEvenIfFirstCohortLoadFails(t *testing.T) {
 	}}
 	flagConfigStorage := newInMemoryFlagConfigStorage()
 	cohortStorage := newInMemoryCohortStorage()
-	cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage)
+	cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage, true)
 
 	runner := newDeploymentRunner(
 		DefaultConfig,
 		flagAPI,
+		nil,
 		flagConfigStorage,
 		cohortStorage,
 		cohortLoader,
diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go
index b8dc324..0076bfc 100644
--- a/pkg/experiment/local/flag_config_api.go
+++ b/pkg/experiment/local/flag_config_api.go
@@ -4,12 +4,13 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/amplitude/experiment-go-server/internal/evaluation"
-	"github.com/amplitude/experiment-go-server/pkg/experiment"
 	"io/ioutil"
 	"net/http"
 	"net/url"
 	"time"
+
+	"github.com/amplitude/experiment-go-server/internal/evaluation"
+	"github.com/amplitude/experiment-go-server/pkg/experiment"
 )
 
 type flagConfigApi interface {
diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go
index 02daea6..e53e2f1 100644
--- a/pkg/experiment/local/flag_config_storage.go
+++ b/pkg/experiment/local/flag_config_storage.go
@@ -1,8 +1,9 @@
 package local
 
 import (
-	"github.com/amplitude/experiment-go-server/internal/evaluation"
 	"sync"
+
+	"github.com/amplitude/experiment-go-server/internal/evaluation"
 )
 
 type flagConfigStorage interface {
@@ -24,6 +25,10 @@ func newInMemoryFlagConfigStorage() *inMemoryFlagConfigStorage {
 	}
 }
 
+func (storage *inMemoryFlagConfigStorage) GetFlagConfigs() map[string]*evaluation.Flag {
+	return storage.getFlagConfigs()
+}
+
 func (storage *inMemoryFlagConfigStorage) getFlagConfig(key string) *evaluation.Flag {
 	storage.flagConfigsLock.Lock()
 	defer storage.flagConfigsLock.Unlock()
diff --git a/pkg/experiment/local/flag_config_stream_api.go b/pkg/experiment/local/flag_config_stream_api.go
new file mode 100644
index 0000000..181cd74
--- /dev/null
+++ b/pkg/experiment/local/flag_config_stream_api.go
@@ -0,0 +1,195 @@
+package local
+
+import (
+	"encoding/json"
+	"errors"
+	"net/url"
+	"sync"
+	"time"
+
+	"github.com/amplitude/experiment-go-server/internal/evaluation"
+)
+
+const streamApiMaxJitter = 5 * time.Second
+const streamApiKeepaliveTimeout = 17 * time.Second
+const streamApiReconnInterval = 15 * time.Minute
+
+type flagConfigStreamApi interface {
+	Connect(
+		onInitUpdate func(map[string]*evaluation.Flag) error,
+		onUpdate func(map[string]*evaluation.Flag) error,
+		onError func(error),
+	) error
+	Close()
+}
+
+type flagConfigStreamApiV2 struct {
+	DeploymentKey       string
+	ServerURL           string
+	connectionTimeout   time.Duration
+	stopCh              chan bool
+	lock                sync.Mutex
+	newSseStreamFactory func(
+		authToken,
+		url string,
+		connectionTimeout time.Duration,
+		keepaliveTimeout time.Duration,
+		reconnInterval time.Duration,
+		maxJitter time.Duration,
+	) stream
+}
+
+func newFlagConfigStreamApiV2(
+	deploymentKey string,
+	serverURL string,
+	connectionTimeout time.Duration,
+) *flagConfigStreamApiV2 {
+	return &flagConfigStreamApiV2{
+		DeploymentKey:       deploymentKey,
+		ServerURL:           serverURL,
+		connectionTimeout:   connectionTimeout,
+		stopCh:              nil,
+		lock:                sync.Mutex{},
+		newSseStreamFactory: newSseStream,
+	}
+}
+
+func (api *flagConfigStreamApiV2) Connect(
+	onInitUpdate func(map[string]*evaluation.Flag) error,
+	onUpdate func(map[string]*evaluation.Flag) error,
+	onError func(error),
+) error {
+	api.lock.Lock()
+	defer api.lock.Unlock()
+
+	api.closeInternal()
+
+	// Create URL.
+	endpoint, err := url.Parse(api.ServerURL)
+	if err != nil {
+		return err
+	}
+	endpoint.Path = "sdk/stream/v1/flags"
+
+	// Create Stream.
+	stream := api.newSseStreamFactory("Api-Key "+api.DeploymentKey, endpoint.String(), api.connectionTimeout, streamApiKeepaliveTimeout, streamApiReconnInterval, streamApiMaxJitter)
+
+	streamMsgCh := make(chan streamEvent)
+	streamErrCh := make(chan error)
+
+	closeStream := func() {
+		stream.Cancel()
+		close(streamMsgCh)
+		close(streamErrCh)
+	}
+
+	// Connect.
+	stream.Connect(streamMsgCh, streamErrCh)
+
+	// Retrieve the first flag configs and parse them.
+	// Any error at this point is treated as an init (connect) error.
+	select {
+	case msg := <-streamMsgCh:
+		// Parse the message and verify the data is correct.
+		flags, err := parseData(msg.data)
+		if err != nil {
+			closeStream()
+			return errors.New("flag config stream api corrupt data, cause: " + err.Error())
+		}
+		if onInitUpdate != nil {
+			err = onInitUpdate(flags)
+		} else if onUpdate != nil {
+			err = onUpdate(flags)
+		}
+		if err != nil {
+			closeStream()
+			return err
+		}
+	case err := <-streamErrCh:
+		// Error when creating the stream.
+		closeStream()
+		return err
+	case <-time.After(api.connectionTimeout):
+		// Timed out.
+		closeStream()
+		return errors.New("flag config stream api connect timeout")
+	}
+
+	// Prep procedures for stopping.
+	stopCh := make(chan bool)
+	api.stopCh = stopCh
+
+	closeAll := func() {
+		api.lock.Lock()
+		defer api.lock.Unlock()
+		closeStream()
+		if api.stopCh == stopCh {
+			api.stopCh = nil
+		}
+		close(stopCh)
+	}
+
+	// Receive and pass on messages until stopCh closes.
+	go func() {
+		for {
+			select {
+			case <-stopCh: // Receiving on a closed channel returns immediately. The local stopCh is captured here, so it is guaranteed to be non-nil.
+				closeStream()
+				return
+			case msg := <-streamMsgCh:
+				// Parse the message and verify the data is correct.
+				flags, err := parseData(msg.data)
+				if err != nil {
+					// Error, close everything.
+					closeAll()
+					if onError != nil {
+						onError(errors.New("stream corrupt data, cause: " + err.Error()))
+					}
+					return
+				}
+				if onUpdate != nil {
+					// Deliver async. Don't care about any errors.
+					//nolint:errcheck
+					go func() { onUpdate(flags) }()
+				}
+			case err := <-streamErrCh:
+				// Error, close everything.
+				closeAll()
+				if onError != nil {
+					onError(err)
+				}
+				return
+			}
+		}
+	}()
+
+	return nil
+}
+
+func parseData(data []byte) (map[string]*evaluation.Flag, error) {
+	var flagsArray []*evaluation.Flag
+	err := json.Unmarshal(data, &flagsArray)
+	if err != nil {
+		return nil, err
+	}
+	flags := make(map[string]*evaluation.Flag)
+	for _, flag := range flagsArray {
+		flags[flag.Key] = flag
+	}
+
+	return flags, nil
+}
+
+func (api *flagConfigStreamApiV2) closeInternal() {
+	if api.stopCh != nil {
+		close(api.stopCh)
+		api.stopCh = nil
+	}
+}
+
+func (api *flagConfigStreamApiV2) Close() {
+	api.lock.Lock()
+	defer api.lock.Unlock()
+
+	api.closeInternal()
+}
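[Reviewer note] The wire contract implied by this file: the stream is opened against <StreamServerUrl>/sdk/stream/v1/flags with an "Api-Key <deployment key>" token, and each SSE event's data field is a JSON array of flag configs, which parseData re-keys by flag key (see FLAG_1_STR in the tests below). A small sketch of the parsing contract:

	// Event data is a JSON array of evaluation.Flag objects.
	payload := []byte(`[{"key":"flag-a","variants":{},"segments":[]}]`)
	flags, err := parseData(payload)
	if err != nil {
		// On the initial event this surfaces as a Connect error;
		// on later events it is reported through onError.
		panic(err)
	}
	_ = flags["flag-a"] // Re-keyed by flag key.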
diff --git a/pkg/experiment/local/flag_config_stream_api_test.go b/pkg/experiment/local/flag_config_stream_api_test.go
new file mode 100644
index 0000000..8c5e79d
--- /dev/null
+++ b/pkg/experiment/local/flag_config_stream_api_test.go
@@ -0,0 +1,207 @@
+package local
+
+import (
+	"errors"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/amplitude/experiment-go-server/internal/evaluation"
+	"github.com/stretchr/testify/assert"
+)
+
+type mockSseStream struct {
+	// Params
+	authToken         string
+	url               string
+	connectionTimeout time.Duration
+	keepaliveTimeout  time.Duration
+	reconnInterval    time.Duration
+	maxJitter         time.Duration
+
+	// Channels to emit messages to simulate new events received through stream.
+	messageCh chan streamEvent
+	errorCh   chan error
+
+	// Channel to tell there's a connection call.
+	chConnected chan bool
+}
+
+func (s *mockSseStream) Connect(messageCh chan streamEvent, errorCh chan error) {
+	s.messageCh = messageCh
+	s.errorCh = errorCh
+
+	s.chConnected <- true
+}
+
+func (s *mockSseStream) Cancel() {
+}
+
+func (s *mockSseStream) setNewESFactory(f func(httpClient *http.Client, url string, headers map[string]string) eventSource) {
+}
+
+func (s *mockSseStream) newSseStreamFactory(
+	authToken,
+	url string,
+	connectionTimeout time.Duration,
+	keepaliveTimeout time.Duration,
+	reconnInterval time.Duration,
+	maxJitter time.Duration,
+) stream {
+	s.authToken = authToken
+	s.url = url
+	s.connectionTimeout = connectionTimeout
+	s.keepaliveTimeout = keepaliveTimeout
+	s.reconnInterval = reconnInterval
+	s.maxJitter = maxJitter
+	return s
+}
+
+var FLAG_1_STR = []byte("[{\"key\":\"flagkey\",\"variants\":{},\"segments\":[]}]")
+var FLAG_1, _ = parseData(FLAG_1_STR)
+
+func TestFlagConfigStreamApi(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+	receivedMsgCh := make(chan map[string]*evaluation.Flag)
+	receivedErrCh := make(chan error)
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+		sse.messageCh <- streamEvent{data: FLAG_1_STR}
+		assert.Equal(t, FLAG_1, <-receivedMsgCh)
+	}()
+	err := api.Connect(
+		func(m map[string]*evaluation.Flag) error {
+			receivedMsgCh <- m
+			return nil
+		},
+		func(m map[string]*evaluation.Flag) error {
+			receivedMsgCh <- m
+			return nil
+		},
+		func(err error) { receivedErrCh <- err },
+	)
+	assert.Nil(t, err)
+
+	go func() { sse.messageCh <- streamEvent{data: FLAG_1_STR} }()
+	assert.Equal(t, FLAG_1, <-receivedMsgCh)
+	go func() { sse.messageCh <- streamEvent{data: FLAG_1_STR} }()
+	assert.Equal(t, FLAG_1, <-receivedMsgCh)
+
+	api.Close()
+}
+
+func TestFlagConfigStreamApiErrorNoInitialFlags(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+	}()
+	err := api.Connect(nil, nil, nil)
+	assert.Equal(t, errors.New("flag config stream api connect timeout"), err)
+}
+
+func TestFlagConfigStreamApiErrorCorruptInitialFlags(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+	receivedMsgCh := make(chan map[string]*evaluation.Flag)
+	receivedErrCh := make(chan error)
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+		sse.messageCh <- streamEvent{data: []byte("bad data")}
+		<-receivedMsgCh // Should hang as no good data was received.
+		assert.Fail(t, "Bad message went through")
+	}()
+	err := api.Connect(
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(err error) { receivedErrCh <- err },
+	)
+	assert.Equal(t, "flag config stream api corrupt data", strings.Split(err.Error(), ", cause: ")[0])
+}
+
+func TestFlagConfigStreamApiErrorInitialFlagsUpdateFailStopsApi(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+	receivedMsgCh := make(chan map[string]*evaluation.Flag)
+	receivedErrCh := make(chan error)
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+		sse.messageCh <- streamEvent{data: FLAG_1_STR}
+		<-receivedMsgCh // Should hang as no update was received.
+		assert.Fail(t, "Bad message went through")
+	}()
+	err := api.Connect(
+		func(m map[string]*evaluation.Flag) error { return errors.New("bad update") },
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(err error) { receivedErrCh <- err },
+	)
+	assert.Equal(t, errors.New("bad update"), err)
+}
+
+func TestFlagConfigStreamApiErrorInitialFlagsFutureUpdateFailDoesntStopApi(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+	receivedMsgCh := make(chan map[string]*evaluation.Flag)
+	receivedErrCh := make(chan error)
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+		sse.messageCh <- streamEvent{data: FLAG_1_STR}
+		assert.Equal(t, FLAG_1, <-receivedMsgCh) // Blocks until the initial flags are delivered.
+	}()
+	err := api.Connect(
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(m map[string]*evaluation.Flag) error { return errors.New("bad update") },
+		func(err error) { receivedErrCh <- err },
+	)
+	assert.Nil(t, err)
+	// Send an update; this calls the onUpdate callback, which fails.
+	sse.messageCh <- streamEvent{data: FLAG_1_STR}
+	// Make sure the channel is not closed.
+	sse.messageCh <- streamEvent{data: FLAG_1_STR}
+}
+
+func TestFlagConfigStreamApiErrorDuringStreaming(t *testing.T) {
+	sse := mockSseStream{chConnected: make(chan bool)}
+	api := newFlagConfigStreamApiV2("deploymentkey", "serverurl", 1*time.Second)
+	api.newSseStreamFactory = sse.newSseStreamFactory
+	receivedMsgCh := make(chan map[string]*evaluation.Flag)
+	receivedErrCh := make(chan error)
+
+	go func() {
+		// On connect.
+		<-sse.chConnected
+		sse.messageCh <- streamEvent{data: FLAG_1_STR}
+		assert.Equal(t, FLAG_1, <-receivedMsgCh)
+	}()
+	err := api.Connect(
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(m map[string]*evaluation.Flag) error { receivedMsgCh <- m; return nil },
+		func(err error) { receivedErrCh <- err },
+	)
+	assert.Nil(t, err)
+
+	go func() { sse.errorCh <- errors.New("error1") }()
+	assert.Equal(t, errors.New("error1"), <-receivedErrCh)
+
+	// The message channel should be closed.
+	defer mutePanic(nil)
+	sse.messageCh <- streamEvent{data: FLAG_1_STR}
+	assert.Fail(t, "Unexpected message after error")
+}
diff --git a/pkg/experiment/local/flag_config_updater.go b/pkg/experiment/local/flag_config_updater.go
new file mode 100644
index 0000000..384df30
--- /dev/null
+++ b/pkg/experiment/local/flag_config_updater.go
@@ -0,0 +1,425 @@
+package local
+
+import (
+	"sync"
+	"time"
+
+	"github.com/amplitude/experiment-go-server/internal/evaluation"
+	"github.com/amplitude/experiment-go-server/internal/logger"
+)
+
+type flagConfigUpdater interface {
+	// Start the updater. It may be called multiple times.
+	// If start fails, an error is returned and the caller should handle it.
+	// If some other async error happens while updating (after a successful start),
+	// the `func(error)` callback is called.
+	Start(func(error)) error
+	Stop()
+}
+
+// The base for all flag config updaters.
+// Contains a method to properly update the flag configs into storage and download cohorts.
+type flagConfigUpdaterBase struct {
+	flagConfigStorage flagConfigStorage
+	cohortStorage     cohortStorage
+	cohortLoader      *cohortLoader
+	log               *logger.Log
+}
+
+func newFlagConfigUpdaterBase(
+	flagConfigStorage flagConfigStorage,
+	cohortStorage cohortStorage,
+	cohortLoader *cohortLoader,
+	config *Config,
+) flagConfigUpdaterBase {
+	return flagConfigUpdaterBase{
+		flagConfigStorage: flagConfigStorage,
+		cohortStorage:     cohortStorage,
+		cohortLoader:      cohortLoader,
+		log:               logger.New(config.Debug),
+	}
+}
+
+// Updates the received flag configs into storage and downloads cohorts.
+func (u *flagConfigUpdaterBase) update(flagConfigs map[string]*evaluation.Flag) error {
+	flagKeys := make(map[string]struct{})
+	for _, flag := range flagConfigs {
+		flagKeys[flag.Key] = struct{}{}
+	}
+
+	u.flagConfigStorage.removeIf(func(f *evaluation.Flag) bool {
+		_, exists := flagKeys[f.Key]
+		return !exists
+	})
+
+	if u.cohortLoader == nil {
+		for _, flagConfig := range flagConfigs {
+			u.log.Debug("Putting non-cohort flag %s", flagConfig.Key)
+			u.flagConfigStorage.putFlagConfig(flagConfig)
+		}
+		return nil
+	}
+
+	newCohortIDs := make(map[string]struct{})
+	for _, flagConfig := range flagConfigs {
+		for cohortID := range getAllCohortIDsFromFlag(flagConfig) {
+			newCohortIDs[cohortID] = struct{}{}
+		}
+	}
+
+	existingCohortIDs := u.cohortStorage.getCohortIds()
+	cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs)
+
+	// Download all new cohorts
+	u.cohortLoader.downloadCohorts(cohortIDsToDownload)
+
+	// Get updated set of cohort ids
+	updatedCohortIDs := u.cohortStorage.getCohortIds()
+	// Iterate through new flag configs and check if their required cohorts exist
+	for _, flagConfig := range flagConfigs {
+		cohortIDs := getAllCohortIDsFromFlag(flagConfig)
+		missingCohorts := difference(cohortIDs, updatedCohortIDs)
+
+		u.flagConfigStorage.putFlagConfig(flagConfig)
+		u.log.Debug("Putting flag %s", flagConfig.Key)
+		if len(missingCohorts) != 0 {
+			u.log.Error("Flag %s - failed to load cohorts: %v", flagConfig.Key, missingCohorts)
+		}
+	}
+
+	// Delete unused cohorts
+	u.deleteUnusedCohorts()
+	u.log.Debug("Refreshed %d flag configs.", len(flagConfigs))
+
+	return nil
+}
+
+func (u *flagConfigUpdaterBase) deleteUnusedCohorts() {
+	flagCohortIDs := make(map[string]struct{})
+	for _, flag := range u.flagConfigStorage.getFlagConfigs() {
+		for cohortID := range getAllCohortIDsFromFlag(flag) {
+			flagCohortIDs[cohortID] = struct{}{}
+		}
+	}
+
+	storageCohorts := u.cohortStorage.getCohorts()
+	for cohortID := range storageCohorts {
+		if _, exists := flagCohortIDs[cohortID]; !exists {
+			cohort := storageCohorts[cohortID]
+			if cohort != nil {
+				u.cohortStorage.deleteCohort(cohort.GroupType, cohortID)
+			}
+		}
+	}
+}
+
+// The streamer for flag configs. It receives flag configs through server-sent events.
+type flagConfigStreamer struct {
+	flagConfigUpdaterBase
+	flagConfigStreamApi flagConfigStreamApi
+	lock                sync.Mutex
+}
+
+func newFlagConfigStreamer(
+	flagConfigStreamApi flagConfigStreamApi,
+	config *Config,
+	flagConfigStorage flagConfigStorage,
+	cohortStorage cohortStorage,
+	cohortLoader *cohortLoader,
+) flagConfigUpdater {
+	return &flagConfigStreamer{
+		flagConfigStreamApi:   flagConfigStreamApi,
+		flagConfigUpdaterBase: newFlagConfigUpdaterBase(flagConfigStorage, cohortStorage, cohortLoader, config),
+	}
+}
+
+func (s *flagConfigStreamer) Start(onError func(error)) error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	s.stopInternal()
+	return s.flagConfigStreamApi.Connect(
+		func(flags map[string]*evaluation.Flag) error {
+			return s.update(flags)
+		},
+		func(flags map[string]*evaluation.Flag) error {
+			return s.update(flags)
+		},
+		func(err error) {
+			s.Stop()
+			if onError != nil {
+				go func() { onError(err) }()
+			}
+		},
+	)
+}
+
+func (s *flagConfigStreamer) stopInternal() {
+	s.flagConfigStreamApi.Close()
+}
+
+func (s *flagConfigStreamer) Stop() {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	s.stopInternal()
+}
+
+// The poller for flag configs. It polls on the configured interval.
+// On start, it performs an initial poll; if that fails, an error is returned,
+// otherwise the periodic poller starts.
+type flagConfigPoller struct {
+	flagConfigUpdaterBase
+	flagConfigApi flagConfigApi
+	config        *Config
+	poller        *poller
+	lock          sync.Mutex
+}
+
+func newFlagConfigPoller(
+	flagConfigApi flagConfigApi,
+	config *Config,
+	flagConfigStorage flagConfigStorage,
+	cohortStorage cohortStorage,
+	cohortLoader *cohortLoader,
+) flagConfigUpdater {
+	return &flagConfigPoller{
+		flagConfigApi:         flagConfigApi,
+		config:                config,
+		flagConfigUpdaterBase: newFlagConfigUpdaterBase(flagConfigStorage, cohortStorage, cohortLoader, config),
+	}
+}
+
+func (p *flagConfigPoller) Start(onError func(error)) error {
+	p.lock.Lock()
+	defer p.lock.Unlock()
+
+	p.stopInternal()
+
+	if err := p.updateFlagConfigs(); err != nil {
+		p.log.Error("Initial updateFlagConfigs failed: %v", err)
+		return err
+	}
+
+	p.poller = newPoller()
+	p.poller.Poll(p.config.FlagConfigPollerInterval, func() {
+		if err := p.periodicRefresh(); err != nil {
+			p.log.Error("Periodic updateFlagConfigs failed: %v", err)
+			p.Stop()
+			if onError != nil {
+				go func() { onError(err) }()
+			}
+		}
+	})
+	return nil
+}
+
+func (p *flagConfigPoller) periodicRefresh() error {
+	defer func() {
+		if r := recover(); r != nil {
+			p.log.Error("Recovered in periodicRefresh: %v", r)
+		}
+	}()
+	return p.updateFlagConfigs()
+}
+
+func (p *flagConfigPoller) updateFlagConfigs() error {
+	p.log.Debug("Refreshing flag configs.")
+	flagConfigs, err := p.flagConfigApi.getFlagConfigs()
+	if err != nil {
+		p.log.Error("Failed to fetch flag configs: %v", err)
+		return err
+	}
+
+	return p.update(flagConfigs)
+}
+
+func (p *flagConfigPoller) stopInternal() {
+	if p.poller != nil {
+		close(p.poller.shutdown)
+		p.poller = nil
+	}
+}
+
+func (p *flagConfigPoller) Stop() {
+	p.lock.Lock()
+	defer p.lock.Unlock()
+	p.stopInternal()
+}
+
+// A wrapper around flag config updaters to retry and fall back.
+// If the main updater fails, the wrapper falls back to the fallback updater
+// and the main updater enters a retry loop.
+type flagConfigFallbackRetryWrapper struct {
+	log                         *logger.Log
+	mainUpdater                 flagConfigUpdater
+	fallbackUpdater             flagConfigUpdater
+	retryDelay                  time.Duration
+	maxJitter                   time.Duration
+	retryTimer                  *time.Timer
+	fallbackStartRetryDelay     time.Duration
+	fallbackStartRetryMaxJitter time.Duration
+	fallbackStartRetryTimer     *time.Timer
+	lock                        sync.Mutex
+	isRunning                   bool
+}
+
+func newflagConfigFallbackRetryWrapper(
+	mainUpdater flagConfigUpdater,
+	fallbackUpdater flagConfigUpdater,
+	retryDelay time.Duration,
+	maxJitter time.Duration,
+	fallbackStartRetryDelay time.Duration,
+	fallbackStartRetryMaxJitter time.Duration,
+	debug bool,
+) flagConfigUpdater {
+	return &flagConfigFallbackRetryWrapper{
+		log:                         logger.New(debug),
+		mainUpdater:                 mainUpdater,
+		fallbackUpdater:             fallbackUpdater,
+		retryDelay:                  retryDelay,
+		maxJitter:                   maxJitter,
+		fallbackStartRetryDelay:     fallbackStartRetryDelay,
+		fallbackStartRetryMaxJitter: fallbackStartRetryMaxJitter,
+		isRunning:                   false,
+	}
+}
+
+/**
+ * Start tries to start the main updater first.
+ * If main fails to start, the fallback updater is started.
+ * If the fallback fails to start as well, an error is returned.
+ * If the fallback starts, the main updater enters a retry loop and Start returns ok.
+ * After a successful start, if main fails, it enters the retry loop and the fallback is started.
+ * If the fallback fails to start, it enters its own start-retry loop until it starts successfully.
+ * If the fallback starts successfully but fails later, it is not monitored; it is recommended
+ * to wrap the fallback updater with another flagConfigFallbackRetryWrapper.
+ * Since the wrapper always retries, there is no terminal error case, so onError is never called.
+ */
+func (w *flagConfigFallbackRetryWrapper) Start(onError func(error)) error {
+	// if (mainUpdater is flagConfigFallbackRetryWrapper) {
+	//     return errors.New("Do not use flagConfigFallbackRetryWrapper as main updater. Fallback updater will never be used. Rewrite retry and fallback logic.")
+	// }
+
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	if w.retryTimer != nil {
+		w.retryTimer.Stop()
+		w.retryTimer = nil
+	}
+
+	err := w.mainUpdater.Start(func(err error) {
+		w.log.Debug("main updater updating err, starting fallback if available. error: %v", err)
+		go func() { w.scheduleRetry() }() // Don't care if poller start error or not, always retry.
+		go func() { w.fallbackStart() }()
+	})
+	if err == nil {
+		// Main start success, stop fallback.
+		if w.fallbackStartRetryTimer != nil {
+			w.fallbackStartRetryTimer.Stop()
+		}
+		if w.fallbackUpdater != nil {
+			w.fallbackUpdater.Stop()
+		}
+		w.isRunning = true
+		return nil
+	}
+	if w.fallbackUpdater == nil {
+		// No fallback; a main start failure is a wrapper start failure.
+		w.log.Error("main updater start err, no fallback. error: %v", err)
+		return err
+	}
+	w.log.Debug("main updater start err, starting fallback. error: %v", err)
+	err = w.fallbackUpdater.Start(nil)
+	if err != nil {
+		w.log.Debug("fallback updater start failed. error: %v", err)
+		return err
+	}
+
+	w.isRunning = true
+	go func() { w.scheduleRetry() }()
+	return nil
+}
+
+func (w *flagConfigFallbackRetryWrapper) Stop() {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+	w.isRunning = false
+
+	if w.retryTimer != nil {
+		w.retryTimer.Stop()
+		w.retryTimer = nil
+	}
+	w.mainUpdater.Stop()
+	if w.fallbackStartRetryTimer != nil {
+		w.fallbackStartRetryTimer.Stop()
+	}
+	if w.fallbackUpdater != nil {
+		w.fallbackUpdater.Stop()
+	}
+}
+
+func (w *flagConfigFallbackRetryWrapper) scheduleRetry() {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	if !w.isRunning {
+		return
+	}
+
+	if w.retryTimer != nil {
+		w.retryTimer.Stop()
+		w.retryTimer = nil
+	}
+	w.retryTimer = time.AfterFunc(randTimeDuration(w.retryDelay, w.maxJitter), func() {
+		w.lock.Lock()
+		defer w.lock.Unlock()
+
+		if !w.isRunning {
+			return
+		}
+
+		if w.retryTimer != nil {
+			w.retryTimer = nil
+		}
+
+		w.log.Debug("main updater retry start")
+		err := w.mainUpdater.Start(func(err error) {
+			w.log.Debug("main updater updating err, starting fallback if available. error: %v", err)
+			go func() { w.scheduleRetry() }() // Don't care if poller start error or not, always retry.
+			go func() { w.fallbackStart() }()
+		})
+		if err == nil {
+			// Main start success, stop fallback.
+			w.log.Debug("main updater retry start success")
+			if w.fallbackStartRetryTimer != nil {
+				w.fallbackStartRetryTimer.Stop()
+			}
+			if w.fallbackUpdater != nil {
+				w.fallbackUpdater.Stop()
+			}
+			return
+		}
+
+		go func() { w.scheduleRetry() }()
+	})
+}
+
+func (w *flagConfigFallbackRetryWrapper) fallbackStart() {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	if !w.isRunning {
+		return
+	}
+	if w.fallbackUpdater == nil {
+		return
+	}
+
+	err := w.fallbackUpdater.Start(nil)
+	if err != nil {
+		w.log.Debug("fallback updater start failed and scheduling retry")
+		w.fallbackStartRetryTimer = time.AfterFunc(randTimeDuration(w.fallbackStartRetryDelay, w.fallbackStartRetryMaxJitter), func() {
+			w.fallbackStart()
+		})
+	}
+}
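[Reviewer note] How these updaters compose, mirroring newDeploymentRunner in deployment_runner.go above (variable names here are illustrative): the poller is wrapped so it retries itself, then the streamer is wrapped with that poller as its fallback.

	// The poller gets its own retry wrapper with no fallback.
	pollerUpdater := newflagConfigFallbackRetryWrapper(
		newFlagConfigPoller(flagApi, config, flagStorage, cohortStorage, cohortLoader),
		nil, config.FlagConfigPollerInterval, updaterRetryMaxJitter, 0, 0, config.Debug)

	// The streamer is the main updater; on stream failure it retries every
	// streamUpdaterRetryDelay while the poller keeps flags fresh.
	updater := newflagConfigFallbackRetryWrapper(
		newFlagConfigStreamer(streamApi, config, flagStorage, cohortStorage, cohortLoader),
		pollerUpdater, streamUpdaterRetryDelay, updaterRetryMaxJitter,
		config.FlagConfigPollerInterval, 0, config.Debug)

	// Start blocks until one source delivers an initial set of flags.
	err := updater.Start(nil)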
+} + +func TestFlagConfigPollerPollingFail(t *testing.T) { + api, flagConfigStorage, cohortStorage, cohortLoader := createTestPollerObjs() + + poller := newFlagConfigPoller(&api, &Config{FlagConfigPollerInterval: 1 * time.Second}, flagConfigStorage, cohortStorage, cohortLoader) + errorCh := make(chan error) + + // Poller start normal. + api.getFlagConfigsFunc = func() (map[string]*evaluation.Flag, error) { + return FLAG_1, nil + } + err := poller.Start(func(e error) { + errorCh <- e + }) // Start should block for first poll. + assert.Nil(t, err) + assert.Equal(t, FLAG_1, flagConfigStorage.getFlagConfigs()) // Test flags in storage. + + // Return error on poll. + api.getFlagConfigsFunc = func() (map[string]*evaluation.Flag, error) { + return nil, errors.New("flag error") + } + time.Sleep(1100 * time.Millisecond) // Sleep for poller to poll. + assert.Equal(t, errors.New("flag error"), <-errorCh) // Error callback called. + + // Make sure there's no more poll. + api.getFlagConfigsFunc = func() (map[string]*evaluation.Flag, error) { + assert.Fail(t, "Unexpected poll") + return nil, nil + } + time.Sleep(1100 * time.Millisecond) // Wait for a poll which should never happen. + + // Can start again. + api.getFlagConfigsFunc = func() (map[string]*evaluation.Flag, error) { + return map[string]*evaluation.Flag{}, nil + } + err = poller.Start(func(e error) { + errorCh <- e + }) + assert.Nil(t, err) + assert.Equal(t, map[string]*evaluation.Flag{}, flagConfigStorage.getFlagConfigs()) // Test flags in storage. +} + +type mockFlagConfigStreamApi struct { + connectFunc func( + func(map[string]*evaluation.Flag) error, + func(map[string]*evaluation.Flag) error, + func(error), + ) error + closeFunc func() +} + +func (api *mockFlagConfigStreamApi) Connect( + onInitUpdate func(map[string]*evaluation.Flag) error, + onUpdate func(map[string]*evaluation.Flag) error, + onError func(error), +) error { + return api.connectFunc(onInitUpdate, onUpdate, onError) +} +func (api *mockFlagConfigStreamApi) Close() { api.closeFunc() } + +func createTestStreamerObjs() (mockFlagConfigStreamApi, flagConfigStorage, cohortStorage, *cohortLoader) { + api := mockFlagConfigStreamApi{} + cohortDownloadAPI := &mockCohortDownloadApi{} + flagConfigStorage := newInMemoryFlagConfigStorage() + cohortStorage := newInMemoryCohortStorage() + cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage, true) + return api, flagConfigStorage, cohortStorage, cohortLoader +} + +func TestFlagConfigStreamer(t *testing.T) { + api, flagConfigStorage, cohortStorage, cohortLoader := createTestStreamerObjs() + + streamer := newFlagConfigStreamer(&api, &Config{FlagConfigPollerInterval: 1 * time.Second}, flagConfigStorage, cohortStorage, cohortLoader) + errorCh := make(chan error) + + var updateCb func(map[string]*evaluation.Flag) error + api.connectFunc = func( + onInitUpdate func(map[string]*evaluation.Flag) error, + onUpdate func(map[string]*evaluation.Flag) error, + onError func(error), + ) error { + err := onInitUpdate(FLAG_1) + updateCb = onUpdate + return err + } + api.closeFunc = func() { + updateCb = nil + } + + // Streamer start normal. + err := streamer.Start(func(e error) { + errorCh <- e + }) // Start should block for first set of flags. + assert.Nil(t, err) + assert.Equal(t, FLAG_1, flagConfigStorage.getFlagConfigs()) // Test flags in storage. + + // Update flags with empty set. 
+ err = updateCb(map[string]*evaluation.Flag{}) + assert.Nil(t, err) + assert.Equal(t, map[string]*evaluation.Flag{}, flagConfigStorage.getFlagConfigs()) // Empty flags are updated. + + // Stop streamer. + streamer.Stop() + assert.Nil(t, updateCb) // Make sure stream Close is called. + + // Streamer start again. + err = streamer.Start(func(e error) { + errorCh <- e + }) // Start should block for first set of flags. + assert.Nil(t, err) + assert.Equal(t, FLAG_1, flagConfigStorage.getFlagConfigs()) // Test flags in storage. + + streamer.Stop() +} + +func TestFlagConfigStreamerStartFail(t *testing.T) { + api, flagConfigStorage, cohortStorage, cohortLoader := createTestStreamerObjs() + + streamer := newFlagConfigStreamer(&api, &Config{FlagConfigPollerInterval: 1 * time.Second}, flagConfigStorage, cohortStorage, cohortLoader) + errorCh := make(chan error) + + api.connectFunc = func( + onInitUpdate func(map[string]*evaluation.Flag) error, + onUpdate func(map[string]*evaluation.Flag) error, + onError func(error), + ) error { + return errors.New("api connect error") + } + api.closeFunc = func() { + } + + // Streamer start. + err := streamer.Start(func(e error) { + errorCh <- e + }) // Start should block for first set of flags, which is error. + assert.Equal(t, errors.New("api connect error"), err) +} + +func TestFlagConfigStreamerStreamingFail(t *testing.T) { + api, flagConfigStorage, cohortStorage, cohortLoader := createTestStreamerObjs() + + streamer := newFlagConfigStreamer(&api, &Config{FlagConfigPollerInterval: 1 * time.Second}, flagConfigStorage, cohortStorage, cohortLoader) + errorCh := make(chan error) + + var updateCb func(map[string]*evaluation.Flag) error + var errorCb func(error) + api.connectFunc = func( + onInitUpdate func(map[string]*evaluation.Flag) error, + onUpdate func(map[string]*evaluation.Flag) error, + onError func(error), + ) error { + err := onInitUpdate(FLAG_1) + updateCb = onUpdate + errorCb = onError + return err + } + api.closeFunc = func() { + updateCb = nil + errorCb = nil + } + + // Streamer start normal. + err := streamer.Start(func(e error) { + errorCh <- e + }) // Start should block for first set of flags. + assert.Nil(t, err) + assert.Equal(t, FLAG_1, flagConfigStorage.getFlagConfigs()) // Test flags in storage. + + // Stream error. + go func() { errorCb(errors.New("stream error")) }() + assert.Equal(t, errors.New("stream error"), <-errorCh) // Error callback is called. + assert.Nil(t, updateCb) // Make sure stream Close is called. + assert.Nil(t, errorCb) + + // Streamer start again. + flagConfigStorage.removeIf(func(f *evaluation.Flag) bool { return true }) + err = streamer.Start(func(e error) { + errorCh <- e + }) // Start should block for first set of flags. + assert.Nil(t, err) + assert.Equal(t, FLAG_1, flagConfigStorage.getFlagConfigs()) // Test flags in storage. 
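+ // Restarting after Stop reconnects and re-delivers the initial flag payload.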
+
+ streamer.Stop()
+}
+
+type mockFlagConfigUpdater struct {
+ startFunc func(func(error)) error
+ stopFunc func()
+}
+
+func (u *mockFlagConfigUpdater) Start(f func(error)) error { return u.startFunc(f) }
+func (u *mockFlagConfigUpdater) Stop() { u.stopFunc() }
+
+func TestFlagConfigFallbackRetryWrapper(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ fallback := mockFlagConfigUpdater{}
+ fallback.startFunc = func(onError func(error)) error {
+ return nil
+ }
+ fallback.stopFunc = func() {
+ }
+ w := newflagConfigFallbackRetryWrapper(&main, &fallback, 1*time.Second, 0, 1*time.Second, 0, true)
+ err := w.Start(nil)
+ assert.Nil(t, err)
+ assert.NotNil(t, mainOnError)
+
+ w.Stop()
+ assert.Nil(t, mainOnError)
+}
+
+func TestFlagConfigFallbackRetryWrapperBothStartFail(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return errors.New("main start error")
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ fallback := mockFlagConfigUpdater{}
+ fallback.startFunc = func(onError func(error)) error {
+ return errors.New("fallback start error")
+ }
+ fallback.stopFunc = func() {
+ }
+ w := newflagConfigFallbackRetryWrapper(&main, &fallback, 1*time.Second, 0, 1*time.Second, 0, true)
+ err := w.Start(nil)
+ assert.Equal(t, errors.New("fallback start error"), err)
+ assert.NotNil(t, mainOnError)
+ mainOnError = nil
+
+ // Test that no retry is scheduled when Start fails.
+ time.Sleep(2000 * time.Millisecond)
+ assert.Nil(t, mainOnError)
+}
+
+func TestFlagConfigFallbackRetryWrapperMainStartFailFallbackSuccess(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return errors.New("main start error")
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ fallback := mockFlagConfigUpdater{}
+ fallbackStopCh := make(chan bool)
+ fallback.startFunc = func(onError func(error)) error {
+ return nil
+ }
+ fallback.stopFunc = func() {
+ go func() { fallbackStopCh <- true }()
+ }
+ w := newflagConfigFallbackRetryWrapper(&main, &fallback, 1*time.Second, 0, 1*time.Second, 0, true)
+ err := w.Start(nil)
+ assert.Nil(t, err)
+ assert.NotNil(t, mainOnError)
+ mainOnError = nil
+
+ // Test that main keeps retrying when its start fails but fallback start succeeds.
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError) // Main Start was called again.
+ mainOnError = nil
+ select {
+ case <-fallbackStopCh:
+ assert.Fail(t, "Unexpected fallback stopped")
+ default:
+ }
+
+ // Test next retry success.
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError) // Main restarted successfully.
+ <-fallbackStopCh // Fallback stopped.
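+ // Once main starts successfully, the wrapper stops the fallback and cancels
+ // any pending fallback start retries.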
+
+ w.Stop()
+}
+
+func TestFlagConfigFallbackRetryWrapperMainUpdatingFail(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ fallback := mockFlagConfigUpdater{}
+ fallbackStartCh := make(chan bool)
+ fallbackStopCh := make(chan bool)
+ fallback.startFunc = func(onError func(error)) error {
+ go func() { fallbackStartCh <- true }()
+ return nil
+ }
+ fallback.stopFunc = func() {}
+ w := newflagConfigFallbackRetryWrapper(&main, &fallback, 1*time.Second, 0, 1*time.Second, 0, true)
+ // Start succeeds.
+ err := w.Start(nil)
+ assert.Nil(t, err)
+ assert.NotNil(t, mainOnError)
+ select {
+ case <-fallbackStartCh:
+ assert.Fail(t, "Unexpected fallback started")
+ default:
+ }
+
+ // Test that a main updating failure starts the fallback.
+ fallback.stopFunc = func() { // Start tracking fallback stops (Start() may call stops).
+ go func() { fallbackStopCh <- true }()
+ }
+ mainOnError(errors.New("main updating error"))
+ mainOnError = nil
+ <-fallbackStartCh // Fallback started.
+ select {
+ case <-fallbackStopCh:
+ assert.Fail(t, "Unexpected fallback stopped")
+ default:
+ }
+
+ // Test a retry where main start fails again.
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return errors.New("main start error")
+ }
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError) // Main Start was called again.
+ mainOnError = nil
+ select { // Test no changes made to fallback updater.
+ case <-fallbackStartCh:
+ assert.Fail(t, "Unexpected fallback started")
+ case <-fallbackStopCh:
+ assert.Fail(t, "Unexpected fallback stopped")
+ default:
+ }
+
+ // Test next retry success.
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError) // Main restarted successfully.
+ select {
+ case <-fallbackStartCh:
+ assert.Fail(t, "Unexpected fallback started")
+ default:
+ }
+ <-fallbackStopCh // Fallback stopped.
+
+ w.Stop()
+}
+
+func TestFlagConfigFallbackRetryWrapperMainUpdatingFailFallbackStartFail(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ fallback := mockFlagConfigUpdater{}
+ fallbackStartCh := make(chan bool)
+ fallbackStopCh := make(chan bool)
+ fallback.startFunc = func(onError func(error)) error {
+ go func() { fallbackStartCh <- true }()
+ return errors.New("fallback start fail")
+ }
+ fallback.stopFunc = func() {}
+ w := newflagConfigFallbackRetryWrapper(&main, &fallback, 1100*time.Millisecond, 0, 500*time.Millisecond, 0, true)
+ // Start succeeds.
+ err := w.Start(nil)
+ assert.Nil(t, err)
+ assert.NotNil(t, mainOnError)
+ select {
+ case <-fallbackStartCh:
+ assert.Fail(t, "Unexpected fallback started")
+ default:
+ }
+
+ // Test that a main updating failure starts the fallback.
+ mainStartCh := make(chan bool)
+ main.startFunc = func(onError func(error)) error {
+ go func() { mainStartCh <- true }()
+ return errors.New("main start fail")
+ }
+ fallback.stopFunc = func() { // Start tracking fallback stops (Start() may call stops).
+ go func() { fallbackStopCh <- true }()
+ }
+ mainOnError(errors.New("main updating error"))
+ mainOnError = nil
+ <-fallbackStartCh // Fallback start attempted.
+ <-fallbackStartCh // Fallback start retried once.
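+ // Fallback start failures retry on their own 500ms timer, independent of the
+ // 1100ms main retry interval configured above.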
+ select {
+ case <-mainStartCh:
+ assert.Fail(t, "Unexpected main start")
+ default:
+ }
+ <-fallbackStartCh // Fallback start retried a second time.
+ select {
+ case <-mainStartCh:
+ assert.Fail(t, "Unexpected main start")
+ default:
+ }
+ // Main start failed again on retry.
+ <-mainStartCh
+ // Make the next main start succeed.
+ main.startFunc = func(onError func(error)) error {
+ go func() { mainStartCh <- true }()
+ return nil
+ }
+ // Fallback start continues to retry.
+ <-fallbackStartCh // Fallback start retried a third time.
+ select {
+ case <-mainStartCh:
+ assert.Fail(t, "Unexpected main start")
+ default:
+ }
+ <-fallbackStartCh // Fallback start retried a fourth time.
+ select {
+ case <-mainStartCh:
+ assert.Fail(t, "Unexpected main start")
+ default:
+ }
+ // Main start succeeds.
+ <-mainStartCh
+
+ // No more fallback start attempts.
+ time.Sleep(4100 * time.Millisecond)
+ select {
+ case <-fallbackStartCh:
+ assert.Fail(t, "Unexpected fallback start")
+ default:
+ }
+
+ w.Stop()
+}
+
+func TestFlagConfigFallbackRetryWrapperMainOnly(t *testing.T) {
+ main := mockFlagConfigUpdater{}
+ var mainOnError func(error)
+ main.startFunc = func(onError func(error)) error {
+ mainOnError = onError
+ return nil
+ }
+ main.stopFunc = func() {
+ mainOnError = nil
+ }
+ w := newflagConfigFallbackRetryWrapper(&main, nil, 1*time.Second, 0, 1*time.Second, 0, true)
+ err := w.Start(nil)
+ assert.Nil(t, err)
+ assert.NotNil(t, mainOnError)
+
+ // Signal updating error.
+ mainOnError(errors.New("main error"))
+ mainOnError = nil
+
+ // Wait for retry and check.
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError)
+ mainOnError2 := mainOnError
+ mainOnError = nil
+
+ // Check that no more retries happen after a successful start.
+ time.Sleep(1100 * time.Millisecond)
+ assert.Nil(t, mainOnError)
+
+ // Again: signal updating error.
+ mainOnError2(errors.New("main error"))
+ mainOnError = nil
+
+ // Wait for retry and check.
+ time.Sleep(1100 * time.Millisecond)
+ assert.NotNil(t, mainOnError)
+ mainOnError = nil
+
+ // Check that no more retries happen after a successful start.
+ time.Sleep(1100 * time.Millisecond)
+ assert.Nil(t, mainOnError)
+
+ w.Stop()
+}
diff --git a/pkg/experiment/local/stream.go b/pkg/experiment/local/stream.go
new file mode 100644
index 0000000..52df5e6
--- /dev/null
+++ b/pkg/experiment/local/stream.go
@@ -0,0 +1,228 @@
+package local
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/amplitude/experiment-go-server/pkg/experiment"
+ "github.com/r3labs/sse/v2"
+ "gopkg.in/cenkalti/backoff.v1"
+)
+
+// Keep-alive message data: a single space byte.
+const STREAM_KEEP_ALIVE_BYTE = byte(' ')
+
+// Mute panics caused by writing to a closed channel.
+func mutePanic(f func()) {
+ if err := recover(); err != nil && f != nil {
+ f()
+ }
+}
+
+// This is a boiled-down version of sse.Client.
+type eventSource interface {
+ OnDisconnect(fn sse.ConnCallback)
+ OnConnect(fn sse.ConnCallback)
+ SubscribeChanRawWithContext(ctx context.Context, ch chan *sse.Event) error
+}
+
+func newEventSource(httpClient *http.Client, url string, headers map[string]string) eventSource {
+ client := sse.NewClient(url)
+ client.Connection = httpClient
+ client.Headers = headers
+ sse.ClientMaxBufferSize(1 << 32)(client)
+ client.ReconnectStrategy = &backoff.StopBackOff{}
+ return client
+}
+
+type streamEvent struct {
+ data []byte
+}
+
+type stream interface {
+ Connect(messageCh chan streamEvent, errorCh chan error)
+ Cancel()
+ // For testing.
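+ // Lets tests inject a mock eventSource in place of the real sse.Client.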
+ setNewESFactory(f func(httpClient *http.Client, url string, headers map[string]string) eventSource)
+}
+
+type sseStream struct {
+ AuthToken string
+ url string
+ connectionTimeout time.Duration
+ keepaliveTimeout time.Duration
+ reconnInterval time.Duration
+ maxJitter time.Duration
+ lock sync.Mutex
+ cancelClientContext *context.CancelFunc
+ newESFactory func(httpClient *http.Client, url string, headers map[string]string) eventSource
+}
+
+func newSseStream(
+ authToken,
+ url string,
+ connectionTimeout time.Duration,
+ keepaliveTimeout time.Duration,
+ reconnInterval time.Duration,
+ maxJitter time.Duration,
+) stream {
+ return &sseStream{
+ AuthToken: authToken,
+ url: url,
+ connectionTimeout: connectionTimeout,
+ keepaliveTimeout: keepaliveTimeout,
+ reconnInterval: reconnInterval,
+ maxJitter: maxJitter,
+ newESFactory: newEventSource,
+ }
+}
+
+func (s *sseStream) setNewESFactory(f func(httpClient *http.Client, url string, headers map[string]string) eventSource) {
+ s.newESFactory = f
+}
+
+func (s *sseStream) Connect(
+ messageCh chan streamEvent,
+ errorCh chan error,
+) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ s.connectInternal(messageCh, errorCh)
+}
+
+func (s *sseStream) connectInternal(
+ messageCh chan streamEvent,
+ errorCh chan error,
+) {
+ ctx, cancel := context.WithCancel(context.Background())
+ s.cancelClientContext = &cancel
+
+ transport := &http.Transport{
+ Dial: (&net.Dialer{
+ Timeout: s.connectionTimeout,
+ }).Dial,
+ TLSHandshakeTimeout: s.connectionTimeout,
+ ResponseHeaderTimeout: s.connectionTimeout,
+ }
+
+ // The http.Client timeout includes reading the body, which for SSE is the entire stream lifetime until it is closed.
+ httpClient := &http.Client{Transport: transport, Timeout: s.reconnInterval + s.maxJitter} // Max time for this connection.
+
+ client := s.newESFactory(httpClient, s.url, map[string]string{
+ "Authorization": s.AuthToken,
+ "X-Amp-Exp-Library": fmt.Sprintf("experiment-go-server/%v", experiment.VERSION),
+ })
+
+ connectCh := make(chan bool)
+ esMsgCh := make(chan *sse.Event)
+ esConnectErrCh := make(chan error)
+ esDisconnectCh := make(chan bool)
+ // Redirect on disconnect to a channel.
+ client.OnDisconnect(func(s *sse.Client) {
+ select {
+ case <-ctx.Done(): // Cancelled.
+ return
+ default:
+ esDisconnectCh <- true
+ }
+ })
+ // Redirect on connect to a channel.
+ client.OnConnect(func(s *sse.Client) {
+ select {
+ case <-ctx.Done(): // Cancelled.
+ return
+ default:
+ go func() { connectCh <- true }()
+ }
+ })
+ go func() {
+ // Subscribe to messages through the channel.
+ // This should be a non-blocking call, though how long it takes is not guaranteed.
+ err := client.SubscribeChanRawWithContext(ctx, esMsgCh)
+ if err != nil {
+ esConnectErrCh <- err
+ }
+ }()
+
+ cancelWithLock := func() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ cancel()
+ if s.cancelClientContext == &cancel {
+ s.cancelClientContext = nil
+ }
+ }
+ go func() {
+ // First wait for connect.
+ select {
+ case <-ctx.Done(): // Cancelled.
+ return
+ case err := <-esConnectErrCh: // Channel subscribe error.
+ cancelWithLock()
+ defer mutePanic(nil)
+ errorCh <- err
+ return
+ case <-time.After(s.connectionTimeout): // Timeout.
+ cancelWithLock()
+ defer mutePanic(nil)
+ errorCh <- errors.New("stream connection timeout")
+ return
+ case <-connectCh: // Connect callback received.
+ }
+ for {
+ select { // Forced priority on context done.
+ case <-ctx.Done(): // Cancelled.
+ return
+ default:
+ }
+ select {
+ case <-ctx.Done(): // Cancelled.
+ return
+ case <-esDisconnectCh: // Disconnected.
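+ // A server-side disconnect tears down this connection and surfaces an error
+ // so the owner can decide whether to reconnect.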
+ cancelWithLock() + defer mutePanic(nil) + errorCh <- errors.New("stream disconnected error") + return + case event := <-esMsgCh: // Message received. + if len(event.Data) == 1 && event.Data[0] == STREAM_KEEP_ALIVE_BYTE { + // Keep alive. + continue + } + // Possible write to closed channel + // If channel closed, cancel. + defer mutePanic(cancelWithLock) + messageCh <- streamEvent{event.Data} + case <-time.After(s.keepaliveTimeout): // Keep alive timeout. + cancelWithLock() + defer mutePanic(nil) + errorCh <- errors.New("stream keepalive timed out") + } + } + }() + + // Reconnect after interval. + time.AfterFunc(randTimeDuration(s.reconnInterval, s.maxJitter), func() { + select { + case <-ctx.Done(): // Cancelled. + return + default: // Reconnect. + cancelWithLock() + s.connectInternal(messageCh, errorCh) + return + } + }) +} + +func (s *sseStream) Cancel() { + s.lock.Lock() + defer s.lock.Unlock() + if s.cancelClientContext != nil { + (*(s.cancelClientContext))() + s.cancelClientContext = nil + } +} diff --git a/pkg/experiment/local/stream_test.go b/pkg/experiment/local/stream_test.go new file mode 100644 index 0000000..49bbe78 --- /dev/null +++ b/pkg/experiment/local/stream_test.go @@ -0,0 +1,299 @@ +package local + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/r3labs/sse/v2" + "github.com/stretchr/testify/assert" +) + +type mockEventSource struct { + httpClient *http.Client + url string + headers map[string]string + + subscribeChanError error + chConnected chan bool + + ctx context.Context + messageChan chan *sse.Event + onDisCb sse.ConnCallback + onConnCb sse.ConnCallback +} + +func (s *mockEventSource) OnDisconnect(fn sse.ConnCallback) { + s.onDisCb = fn +} + +func (s *mockEventSource) OnConnect(fn sse.ConnCallback) { + s.onConnCb = fn +} + +func (s *mockEventSource) SubscribeChanRawWithContext(ctx context.Context, ch chan *sse.Event) error { + s.ctx = ctx + s.messageChan = ch + s.chConnected <- true + return s.subscribeChanError +} + +func (s *mockEventSource) mockEventSourceFactory(httpClient *http.Client, url string, headers map[string]string) eventSource { + s.httpClient = httpClient + s.url = url + s.headers = headers + return s +} + +func TestStream(t *testing.T) { + var s = mockEventSource{chConnected: make(chan bool)} + client := newSseStream("authToken", "url", 2*time.Second, 4*time.Second, 6*time.Second, 1*time.Second) + client.setNewESFactory(s.mockEventSourceFactory) + messageCh := make(chan streamEvent) + errorCh := make(chan error) + + // Make connection. + client.Connect(messageCh, errorCh) + // Wait for connection "establish". + <-s.chConnected + + // Check for all variables. + assert.Equal(t, "url", s.url) + assert.Equal(t, "authToken", s.headers["Authorization"]) + assert.NotNil(t, s.headers["X-Amp-Exp-Library"]) + + // Signal connected. + s.onConnCb(nil) + + // Send update 1, ensure received. + go func() { s.messageChan <- &sse.Event{Data: []byte("data1")} }() + assert.Equal(t, []byte("data1"), (<-messageCh).data) + + // Send keep alive, not passed down, checked later along with updates 2 and 3. + go func() { s.messageChan <- &sse.Event{Data: []byte(" ")} }() + + // Send update 2 and 3, ensure received in order. + go func() { + s.messageChan <- &sse.Event{Data: []byte("data2")} + s.messageChan <- &sse.Event{Data: []byte("data3")} + }() + assert.Equal(t, []byte("data2"), (<-messageCh).data) + assert.Equal(t, []byte("data3"), (<-messageCh).data) + + // Stop client, ensure context cancelled. 
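+ // Cancel closes the context passed to SubscribeChanRawWithContext, which the
+ // mock observes below.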
+ client.Cancel()
+ assert.True(t, errors.Is(s.ctx.Err(), context.Canceled))
+
+ // No message is passed through after cancel, even if one is received.
+ go func() { s.messageChan <- &sse.Event{Data: []byte("data4")} }()
+
+ // Ensure no message after cancel.
+ select {
+ case msg, ok := <-messageCh:
+ if ok {
+ assert.Fail(t, "Unexpected data message received", string(msg.data))
+ }
+ case err, ok := <-errorCh:
+ if ok {
+ assert.Fail(t, "Unexpected error message received", err)
+ }
+ case <-time.After(1 * time.Second):
+ // No message received within the timeout, as expected
+ }
+}
+
+func TestStreamConnTimeout(t *testing.T) {
+ var s = mockEventSource{chConnected: make(chan bool)}
+ client := newSseStream("", "", 2*time.Second, 4*time.Second, 6*time.Second, 1*time.Second)
+ client.setNewESFactory(s.mockEventSourceFactory)
+ messageCh := make(chan streamEvent)
+ errorCh := make(chan error)
+
+ // Make connection.
+ client.Connect(messageCh, errorCh)
+ <-s.chConnected
+ // Wait for the connection timeout to be reached.
+ time.Sleep(2*time.Second + 10*time.Millisecond)
+ // Check that the context is cancelled and the error is received.
+ assert.True(t, errors.Is(s.ctx.Err(), context.Canceled))
+ assert.Equal(t, errors.New("stream connection timeout"), <-errorCh)
+}
+
+func TestStreamKeepAliveTimeout(t *testing.T) {
+ var s = mockEventSource{chConnected: make(chan bool)}
+ client := newSseStream("", "", 2*time.Second, 1*time.Second, 6*time.Second, 1*time.Second)
+ client.setNewESFactory(s.mockEventSourceFactory)
+ messageCh := make(chan streamEvent)
+ errorCh := make(chan error)
+
+ // Make connection.
+ client.Connect(messageCh, errorCh)
+ <-s.chConnected
+ s.onConnCb(nil)
+
+ // Send keepalive 1 and wait.
+ go func() { s.messageChan <- &sse.Event{Data: []byte(" ")} }()
+ time.Sleep(1*time.Second - 10*time.Millisecond)
+ assert.False(t, errors.Is(s.ctx.Err(), context.Canceled))
+ // Send keepalive 2 and wait.
+ go func() { s.messageChan <- &sse.Event{Data: []byte(" ")} }()
+ time.Sleep(1*time.Second - 10*time.Millisecond)
+ assert.False(t, errors.Is(s.ctx.Err(), context.Canceled))
+ // Send data and wait; data should reset the keepalive timer.
+ go func() { s.messageChan <- &sse.Event{Data: []byte("data1")} }()
+ assert.Equal(t, []byte("data1"), (<-messageCh).data)
+ time.Sleep(1*time.Second - 10*time.Millisecond)
+ assert.False(t, errors.Is(s.ctx.Err(), context.Canceled))
+ // Send data to ensure the stream is still open.
+ go func() { s.messageChan <- &sse.Event{Data: []byte("data1")} }()
+ assert.Equal(t, []byte("data1"), (<-messageCh).data)
+ assert.False(t, errors.Is(s.ctx.Err(), context.Canceled))
+ // Wait for the keepalive to time out; the stream should close.
+ time.Sleep(1*time.Second + 10*time.Millisecond)
+ assert.Equal(t, errors.New("stream keepalive timed out"), <-errorCh)
+ assert.True(t, errors.Is(s.ctx.Err(), context.Canceled))
+}
+
+func TestStreamReconnectsTimeout(t *testing.T) {
+ var s = mockEventSource{chConnected: make(chan bool)}
+ client := newSseStream("", "", 2*time.Second, 3*time.Second, 2*time.Second, 0*time.Second)
+ client.setNewESFactory(s.mockEventSourceFactory)
+ messageCh := make(chan streamEvent)
+ errorCh := make(chan error)
+
+ // Make connection.
+ client.Connect(messageCh, errorCh)
+ <-s.chConnected
+ s.onConnCb(nil)
+
+ go func() { s.messageChan <- &sse.Event{Data: []byte("data1")} }()
+ assert.Equal(t, []byte("data1"), (<-messageCh).data)
+ // Sleep until the reconnect interval elapses; data should still pass through.
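+ // reconnInterval is 2s with zero jitter here, so the stream re-subscribes on
+ // schedule and the mock receives a fresh SubscribeChanRawWithContext call.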
+ time.Sleep(2*time.Second + 100*time.Millisecond) + <-s.chConnected + s.onConnCb(nil) + go func() { s.messageChan <- &sse.Event{Data: []byte(" ")} }() + go func() { s.messageChan <- &sse.Event{Data: []byte("data2")} }() + assert.Equal(t, []byte("data2"), (<-messageCh).data) + assert.False(t, errors.Is(s.ctx.Err(), context.Canceled)) + // Cancel stream, should cancel context. + client.Cancel() + assert.True(t, errors.Is(s.ctx.Err(), context.Canceled)) + select { + case msg, ok := <-errorCh: + if ok { + assert.Fail(t, "Unexpected message received after disconnect", msg) + } + case <-time.After(3 * time.Second): + // No message received within the timeout, as expected + } +} + +func TestStreamConnectAndCancelImmediately(t *testing.T) { + var s = mockEventSource{chConnected: make(chan bool)} + client := newSseStream("", "", 2*time.Second, 3*time.Second, 2*time.Second, 0*time.Second) + client.setNewESFactory(s.mockEventSourceFactory) + messageCh := make(chan streamEvent) + errorCh := make(chan error) + + // Make connection and cancel immediately. + client.Connect(messageCh, errorCh) + client.Cancel() + // Make sure no error for all timeouts. + select { + case msg, ok := <-errorCh: + if ok { + assert.Fail(t, "Unexpected message received after disconnect", msg) + } + case <-time.After(4 * time.Second): + // No message received within the timeout, as expected + } +} + +func TestStreamChannelCloseOk(t *testing.T) { + var s = mockEventSource{chConnected: make(chan bool)} + client := newSseStream("", "", 1*time.Second, 1*time.Second, 1*time.Second, 0*time.Second) + client.setNewESFactory(s.mockEventSourceFactory) + messageCh := make(chan streamEvent) + errorCh := make(chan error) + + // Close channels. + close(messageCh) + close(errorCh) + + // Connect and send message, the client should cancel right away. + client.Connect(messageCh, errorCh) + <-s.chConnected + s.onConnCb(nil) + + // Test no message received for closed channel. + s.messageChan <- &sse.Event{Data: []byte("data1")} + assert.True(t, errors.Is(s.ctx.Err(), context.Canceled)) + + select { + case msg, ok := <-messageCh: + if ok { + assert.Fail(t, "Unexpected message received after close", msg) + } + case msg, ok := <-errorCh: + if ok { + assert.Fail(t, "Unexpected message received after close", msg) + } + case <-time.After(2 * time.Second): + // No message received within the timeout, as expected + } +} + +func TestStreamDisconnectErrorPasses(t *testing.T) { + var s = mockEventSource{chConnected: make(chan bool)} + client := newSseStream("", "", 1*time.Second, 1*time.Second, 1*time.Second, 0*time.Second) + client.setNewESFactory(s.mockEventSourceFactory) + messageCh := make(chan streamEvent) + errorCh := make(chan error) + + // Make connection. + client.Connect(messageCh, errorCh) + <-s.chConnected + s.onConnCb(nil) + + // Disconnect error goes through. + s.onDisCb(nil) + assert.Equal(t, errors.New("stream disconnected error"), <-errorCh) + + select { + case msg, ok := <-errorCh: + if ok { + assert.Fail(t, "Unexpected message received after disconnect", msg) + } + case <-time.After(2 * time.Second): + // No message received within the timeout, as expected + } +} + +func TestStreamConnectErrorPasses(t *testing.T) { + var s = mockEventSource{chConnected: make(chan bool)} + client := newSseStream("", "", 1*time.Second, 1*time.Second, 1*time.Second, 0*time.Second) + client.setNewESFactory(s.mockEventSourceFactory) + messageCh := make(chan streamEvent) + errorCh := make(chan error) + + // Make connection. 
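+ // Prime the mock so SubscribeChanRawWithContext returns an error, exercising
+ // the subscribe-failure path.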
+ s.subscribeChanError = errors.New("some error occurred") + client.Connect(messageCh, errorCh) + <-s.chConnected + s.onConnCb(nil) + + // Connect error goes through. + assert.Equal(t, errors.New("some error occurred"), <-errorCh) + + select { + case msg, ok := <-errorCh: + if ok { + assert.Fail(t, "Unexpected message received after disconnect", msg) + } + case <-time.After(2 * time.Second): + // No message received within the timeout, as expected + } +} diff --git a/pkg/experiment/local/util.go b/pkg/experiment/local/util.go index 3e8a2d3..872dd3e 100644 --- a/pkg/experiment/local/util.go +++ b/pkg/experiment/local/util.go @@ -1,5 +1,11 @@ package local +import ( + "math" + "math/rand" + "time" +) + func hashCode(s string) int { hash := 0 if len(s) == 0 { @@ -12,3 +18,29 @@ func hashCode(s string) int { } return hash } + +func difference(set1, set2 map[string]struct{}) map[string]struct{} { + diff := make(map[string]struct{}) + for k := range set1 { + if _, exists := set2[k]; !exists { + diff[k] = struct{}{} + } + } + return diff +} + +func randTimeDuration(base time.Duration, jitter time.Duration) time.Duration { + if jitter == 0 { + return base + } + dmin := base.Nanoseconds() - jitter.Nanoseconds() + if dmin < 0 { + dmin = 0 + } + dmiddle := base.Nanoseconds() + if dmiddle > math.MaxInt64-jitter.Nanoseconds() { + dmiddle = math.MaxInt64 - jitter.Nanoseconds() + } + dmax := dmiddle + jitter.Nanoseconds() + return time.Duration(dmin + rand.Int63n(dmax-dmin)) +}
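+
+// Example: randTimeDuration(10*time.Second, time.Second) returns a duration
+// drawn uniformly from [9s, 11s); with zero jitter it returns base unchanged.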