From 2e930859c1461a5b5a0ac5a1c3b406045a85899e Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Sat, 5 Oct 2024 01:09:53 -0700 Subject: [PATCH] add stream client test, fix stream concurrency, prioritize ctx check --- pkg/experiment/local/client_stream_test.go | 157 ++++++++++++++++++ pkg/experiment/local/client_test.go | 49 ------ pkg/experiment/local/deployment_runner.go | 4 +- pkg/experiment/local/flag_config_updater.go | 14 +- .../local/flag_config_updater_test.go | 10 +- pkg/experiment/local/stream.go | 36 ++-- pkg/experiment/local/stream_test.go | 13 +- 7 files changed, 203 insertions(+), 80 deletions(-) create mode 100644 pkg/experiment/local/client_stream_test.go diff --git a/pkg/experiment/local/client_stream_test.go b/pkg/experiment/local/client_stream_test.go new file mode 100644 index 0000000..b3627b6 --- /dev/null +++ b/pkg/experiment/local/client_stream_test.go @@ -0,0 +1,157 @@ +package local + +import ( + "log" + "os" + "testing" + + "github.com/amplitude/experiment-go-server/pkg/experiment" + "github.com/joho/godotenv" + "github.com/stretchr/testify/assert" +) + +var streamClient *Client + +func init() { + err := godotenv.Load() + if err != nil { + log.Printf("Error loading .env file: %v", err) + } + projectApiKey := os.Getenv("API_KEY") + secretKey := os.Getenv("SECRET_KEY") + cohortSyncConfig := CohortSyncConfig{ + ApiKey: projectApiKey, + SecretKey: secretKey, + } + streamClient = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", + &Config{ + StreamUpdates: true, + StreamServerUrl: "https://stream.lab.amplitude.com", + CohortSyncConfig: &cohortSyncConfig, + }) + err = streamClient.Start() + if err != nil { + panic(err) + } +} + +func TestMakeSureStreamEnabled(t *testing.T) { + assert.True(t, streamClient.config.StreamUpdates) +} + +func TestStreamEvaluate(t *testing.T) { + user := &experiment.User{UserId: "test_user"} + result, err := streamClient.Evaluate(user, nil) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Payload != "payload" { + t.Fatalf("Unexpected variant %v", variant) + } + variant = result["sdk-ci-test"] + if variant.Key != "" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "" { + t.Fatalf("Unexpected variant %v", variant) + } +} + +func TestStreamEvaluateV2AllFlags(t *testing.T) { + user := &experiment.User{UserId: "test_user"} + result, err := streamClient.EvaluateV2(user, nil) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Payload != "payload" { + t.Fatalf("Unexpected variant %v", variant) + } + variant = result["sdk-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "" { + t.Fatalf("Unexpected variant %v", variant) + } +} + +func TestStreamFlagMetadataLocalFlagKey(t *testing.T) { + md := streamClient.FlagMetadata("sdk-local-evaluation-ci-test") + if md["evaluationMode"] != "local" { + t.Fatalf("Unexpected metadata %v", md) + } +} + +func TestStreamEvaluateV2Cohort(t *testing.T) { + targetedUser := &experiment.User{UserId: "12345"} + nonTargetedUser := &experiment.User{UserId: "not_targeted"} + flagKeys := []string{"sdk-local-evaluation-user-cohort-ci-test"} + result, err := streamClient.EvaluateV2(targetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + result, err = streamClient.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } +} + +func TestStreamEvaluateV2GroupCohort(t *testing.T) { + targetedUser := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"1"}, + }} + nonTargetedUser := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"not_targeted"}, + }} + flagKeys := []string{"sdk-local-evaluation-group-cohort-ci-test"} + result, err := streamClient.EvaluateV2(targetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + result, err = streamClient.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } +} diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go index bea47b8..8cb1b68 100644 --- a/pkg/experiment/local/client_test.go +++ b/pkg/experiment/local/client_test.go @@ -4,7 +4,6 @@ import ( "log" "os" "testing" - "time" "github.com/amplitude/experiment-go-server/pkg/experiment" "github.com/joho/godotenv" @@ -232,51 +231,3 @@ func TestEvaluateV2GroupCohort(t *testing.T) { t.Fatalf("Unexpected variant %v", variant) } } - - -func TestMain(t *testing.T) { - client := Initialize("server-tUTqR62DZefq7c73zMpbIr1M5VDtwY8T", &Config{ServerUrl: "noserver", StreamUpdates: true, StreamServerUrl: "https://skylab-stream.stag2.amplitude.com"}) - client.Start() - println(client.flagConfigStorage.getFlagConfigs(), len(client.flagConfigStorage.getFlagConfigs())) - time.Sleep(2000 * time.Millisecond) - println(client.flagConfigStorage.getFlagConfigs(), len(client.flagConfigStorage.getFlagConfigs())) - - // connTimeout := 1500 * time.Millisecond - // api := NewFlagConfigStreamApiV2("server-tUTqR62DZefq7c73zMpbIr1M5VDtwY8T", "https://skylab-stream.stag2.amplitude.com", connTimeout) - // cohortStorage := newInMemoryCohortStorage() - // flagConfigStorage := newInMemoryFlagConfigStorage() - // dr := newDeploymentRunner( - // DefaultConfig, - // NewFlagConfigApiV2("server-tUTqR62DZefq7c73zMpbIr1M5VDtwY8T", "https://skylab-api.staging.amplitude.com", connTimeout), - // api, - // flagConfigStorage, cohortStorage, nil) - // println("inited") - // // time.Sleep(5000 * time.Millisecond) - // dr.start() - - // for { - // fmt.Printf("%v+\n", time.Now()) - // fmt.Println(flagConfigStorage.GetFlagConfigs()) - // time.Sleep(5000 * time.Millisecond) - // fmt.Println(flagConfigStorage.GetFlagConfigs()) - // } - - // if len(os.Args) < 2 { - // fmt.Printf("error: command required\n") - // fmt.Printf("Available commands:\n" + - // " fetch\n" + - // " flags\n" + - // " evaluate\n") - // return - // } - // switch os.Args[1] { - // case "fetch": - // fetch() - // case "flags": - // flags() - // case "evaluate": - // evaluate() - // default: - // fmt.Printf("error: unknown sub-command '%v'", os.Args[1]) - // } -} \ No newline at end of file diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go index 08b4ca6..2997ba5 100644 --- a/pkg/experiment/local/deployment_runner.go +++ b/pkg/experiment/local/deployment_runner.go @@ -28,9 +28,9 @@ func newDeploymentRunner( cohortStorage cohortStorage, cohortLoader *cohortLoader, ) *deploymentRunner { - flagConfigUpdater := NewFlagConfigFallbackRetryWrapper(NewFlagConfigPoller(flagConfigApi, config, flagConfigStorage, cohortStorage, cohortLoader), nil, config.FlagConfigPollerInterval, updaterRetryMaxJitter) + flagConfigUpdater := NewFlagConfigFallbackRetryWrapper(NewFlagConfigPoller(flagConfigApi, config, flagConfigStorage, cohortStorage, cohortLoader), nil, config.FlagConfigPollerInterval, updaterRetryMaxJitter, config.Debug) if (flagConfigStreamApi != nil) { - flagConfigUpdater = NewFlagConfigFallbackRetryWrapper(NewFlagConfigStreamer(flagConfigStreamApi, config, flagConfigStorage, cohortStorage, cohortLoader), flagConfigUpdater, streamUpdaterRetryDelay, updaterRetryMaxJitter) + flagConfigUpdater = NewFlagConfigFallbackRetryWrapper(NewFlagConfigStreamer(flagConfigStreamApi, config, flagConfigStorage, cohortStorage, cohortLoader), flagConfigUpdater, streamUpdaterRetryDelay, updaterRetryMaxJitter, config.Debug) } dr := &deploymentRunner{ config: config, diff --git a/pkg/experiment/local/flag_config_updater.go b/pkg/experiment/local/flag_config_updater.go index 1bd1d12..0217710 100644 --- a/pkg/experiment/local/flag_config_updater.go +++ b/pkg/experiment/local/flag_config_updater.go @@ -1,7 +1,6 @@ package local import ( - "fmt" "sync" "time" @@ -251,6 +250,7 @@ func (p *flagConfigPoller) Stop() { // A wrapper around flag config updaters to retry and fallback. // If the main updater fails, it will fallback to the fallback updater and main updater enters retry loop. type FlagConfigFallbackRetryWrapper struct { + log *logger.Log mainUpdater flagConfigUpdater fallbackUpdater flagConfigUpdater retryDelay time.Duration @@ -263,8 +263,10 @@ func NewFlagConfigFallbackRetryWrapper( fallbackUpdater flagConfigUpdater, retryDelay time.Duration, maxJitter time.Duration, + debug bool, ) flagConfigUpdater { return &FlagConfigFallbackRetryWrapper{ + log: logger.New(debug), mainUpdater: mainUpdater, fallbackUpdater: fallbackUpdater, retryDelay: retryDelay, @@ -291,7 +293,8 @@ func (w *FlagConfigFallbackRetryWrapper) Start(onError func (error)) error { w.retryTimer = nil } - err := w.mainUpdater.Start(func (error) { + err := w.mainUpdater.Start(func (err error) { + w.log.Error("main updater updating err, starting fallback if available. error: ", err) go func() {w.scheduleRetry()}() // Don't care if poller start error or not, always retry. if (w.fallbackUpdater != nil) { w.fallbackUpdater.Start(nil) @@ -304,14 +307,14 @@ func (w *FlagConfigFallbackRetryWrapper) Start(onError func (error)) error { } return nil } - fmt.Println("main start err", err) - // Logger.e("Primary flag configs start failed, start fallback. Error: ", t) + w.log.Debug("main updater start err, starting fallback. error: ", err) if (w.fallbackUpdater == nil) { // No fallback, main start failed is wrapper start fail return err } err = w.fallbackUpdater.Start(nil) if (err != nil) { + w.log.Debug("fallback updater start failed. error: ", err) return err } @@ -349,6 +352,7 @@ func (w *FlagConfigFallbackRetryWrapper) scheduleRetry() { w.retryTimer = nil } + w.log.Debug("main updater retry start") err := w.mainUpdater.Start(func (error) { go func() {w.scheduleRetry()}() // Don't care if poller start error or not, always retry. if (w.fallbackUpdater != nil) { @@ -357,12 +361,12 @@ func (w *FlagConfigFallbackRetryWrapper) scheduleRetry() { }) if (err == nil) { // Main start success, stop fallback. + w.log.Debug("main updater retry start success") if (w.fallbackUpdater != nil) { w.fallbackUpdater.Stop() } return } - fmt.Println("retrying failed", err) go func() {w.scheduleRetry()}() }) diff --git a/pkg/experiment/local/flag_config_updater_test.go b/pkg/experiment/local/flag_config_updater_test.go index 34dddfc..e710313 100644 --- a/pkg/experiment/local/flag_config_updater_test.go +++ b/pkg/experiment/local/flag_config_updater_test.go @@ -267,7 +267,7 @@ func TestFlagConfigFallbackRetryWrapper(t *testing.T) { } fallback.stopFunc = func () { } - w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0) + w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0, true) err := w.Start(nil) assert.Nil(t, err) assert.NotNil(t, mainOnError) @@ -292,7 +292,7 @@ func TestFlagConfigFallbackRetryWrapperBothStartFail(t *testing.T) { } fallback.stopFunc = func () { } - w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0) + w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0, true) err := w.Start(nil) assert.Equal(t, errors.New("fallback start error"), err) assert.NotNil(t, mainOnError) @@ -321,7 +321,7 @@ func TestFlagConfigFallbackRetryWrapperMainStartFailFallbackSuccess(t *testing.T fallback.stopFunc = func () { go func() {fallbackStopCh <- true} () } - w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0) + w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0, true) err := w.Start(nil) assert.Nil(t, err) assert.NotNil(t, mainOnError) @@ -366,7 +366,7 @@ func TestFlagConfigFallbackRetryWrapperMainUpdatingFail(t *testing.T) { return nil } fallback.stopFunc = func () {} - w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0) + w := NewFlagConfigFallbackRetryWrapper(&main, &fallback, 1 * time.Second, 0, true) // Start success err := w.Start(nil) assert.Nil(t, err) @@ -429,7 +429,7 @@ func TestFlagConfigFallbackRetryWrapperMainOnly(t *testing.T) { main.stopFunc = func () { mainOnError = nil } - w := NewFlagConfigFallbackRetryWrapper(&main, nil, 1 * time.Second, 0) + w := NewFlagConfigFallbackRetryWrapper(&main, nil, 1 * time.Second, 0, true) err := w.Start(nil) assert.Nil(t, err) assert.NotNil(t, mainOnError) diff --git a/pkg/experiment/local/stream.go b/pkg/experiment/local/stream.go index 252463b..b7c2f1a 100644 --- a/pkg/experiment/local/stream.go +++ b/pkg/experiment/local/stream.go @@ -63,7 +63,7 @@ type SseStream struct { reconnInterval time.Duration maxJitter time.Duration lock sync.Mutex - cancelClientContext context.CancelFunc + cancelClientContext *context.CancelFunc newESFactory func (httpClient *http.Client, url string, headers map[string]string) EventSource } @@ -100,7 +100,7 @@ func (s *SseStream) connectInternal( errorCh chan error, ) error { ctx, cancel := context.WithCancel(context.Background()) - s.cancelClientContext = cancel + s.cancelClientContext = &cancel transport := &http.Transport{ Dial: (&net.Dialer{ @@ -137,7 +137,7 @@ func (s *SseStream) connectInternal( case <-ctx.Done(): // Cancelled. return default: - connectCh <- true + go func() {connectCh <- true} () } }) go func() { @@ -149,31 +149,45 @@ func (s *SseStream) connectInternal( } }() + cancelWithLock := func() { + s.lock.Lock() + defer s.lock.Unlock() + cancel() + if (s.cancelClientContext == &cancel) { + s.cancelClientContext = nil + } + } go func() { // First wait for connect. select { case <-ctx.Done(): // Cancelled. return case err := <-esConnectErrCh: // Channel subscribe error. - cancel() + cancelWithLock() defer mutePanic(nil) errorCh <- err return case <-time.After(s.connectionTimeout): // Timeout. - cancel() + cancelWithLock() defer mutePanic(nil) errorCh <- errors.New("stream connection timeout") return case <-connectCh: // Connected callbacked. } for { + select { // Forced priority on context done. + case <-ctx.Done(): // Cancelled. + return + default: + } select { case <-ctx.Done(): // Cancelled. return case <- esDisconnectCh: // Disconnected. - cancel() + cancelWithLock() defer mutePanic(nil) errorCh <- errors.New("stream disconnected error") + return case event := <-esMsgCh: // Message received. if (len(event.Data) == 1 && event.Data[0] == STREAM_KEEP_ALIVE_BYTE) { // Keep alive. @@ -181,10 +195,10 @@ func (s *SseStream) connectInternal( } // Possible write to closed channel // If channel closed, cancel. - defer mutePanic(cancel) + defer mutePanic(cancelWithLock) messageCh <- StreamEvent{event.Data} case <-time.After(s.keepaliveTimeout): // Keep alive timeout. - cancel() + cancelWithLock() defer mutePanic(nil) errorCh <- errors.New("stream keepalive timed out") } @@ -193,13 +207,11 @@ func (s *SseStream) connectInternal( // Reconnect after interval. time.AfterFunc(randTimeDuration(s.reconnInterval, s.maxJitter), func () { - s.lock.Lock() - defer s.lock.Unlock() select { case <-ctx.Done(): // Cancelled. return default: // Reconnect. - cancel() + cancelWithLock() s.connectInternal(messageCh, errorCh) return } @@ -212,7 +224,7 @@ func (s *SseStream) Cancel() { s.lock.Lock() defer s.lock.Unlock() if (s.cancelClientContext != nil) { - s.cancelClientContext() + (*(s.cancelClientContext))() s.cancelClientContext = nil } } diff --git a/pkg/experiment/local/stream_test.go b/pkg/experiment/local/stream_test.go index 0e59646..f8e8276 100644 --- a/pkg/experiment/local/stream_test.go +++ b/pkg/experiment/local/stream_test.go @@ -49,8 +49,6 @@ func (s *mockEventSource) mockEventSourceFactory(httpClient *http.Client, url st func TestStream(t *testing.T) { var s = mockEventSource{chConnected: make(chan bool)} - // assert.Equal(t, 2 * time.Second, connTimeout) - // assert.Equal(t, 7 * time.Second, maxTime) client := NewSseStream("authToken", "url", 2 * time.Second, 4 * time.Second, 6 * time.Second, 1 * time.Second) client.setNewESFactory(s.mockEventSourceFactory) messageCh := make(chan StreamEvent) @@ -88,18 +86,18 @@ func TestStream(t *testing.T) { client.Cancel() assert.True(t, errors.Is(s.ctx.Err(), context.Canceled)) - // No message is passed through even it's received. + // No message is passed through after cancel even it's received. go func() {s.messageChan <- &sse.Event{Data: []byte("data4")}}() // Ensure no message after cancel. select { case msg, ok := <-messageCh: if ok { - assert.Fail(t, "Unexpected data message received", msg) + assert.Fail(t, "Unexpected data message received", string(msg.data)) } - case msg, ok := <-errorCh: + case err, ok := <-errorCh: if ok { - assert.Fail(t, "Unexpected error message received", msg) + assert.Fail(t, "Unexpected error message received", err) } case <-time.After(1 * time.Second): // No message received within the timeout, as expected @@ -188,7 +186,7 @@ func TestStreamReconnectsTimeout(t *testing.T) { if ok { assert.Fail(t, "Unexpected message received after disconnect", msg) } - case <-time.After(6 * time.Second): + case <-time.After(3 * time.Second): // No message received within the timeout, as expected } } @@ -230,6 +228,7 @@ func TestStreamChannelCloseOk(t *testing.T) { <-s.chConnected s.onConnCb(nil) + // Test no message received for closed channel. s.messageChan <- &sse.Event{Data: []byte("data1")} assert.True(t, errors.Is(s.ctx.Err(), context.Canceled))